void operator()

in tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc [209:327]


  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  const T* __restrict__ grad_output, T* __restrict__ grad_data,
                  T* __restrict__ grad_warp, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
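    // Computes the gradient of the 2D resampler (bilinear sampling with
    // implicit zero padding) with respect to both the input data and the
    // warp field.
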
    // Set gradients to 0, because the kernel incrementally updates the
    // tensor entries by adding partial contributions.
    const int resampler_output_size =
        batch_size * num_sampling_points * data_channels;
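    // grad_warp stores one (x, y) pair per sampling point, i.e.
    // batch_size * num_sampling_points * 2 entries in total.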
    const int grad_warp_size = resampler_output_size / data_channels * 2;
    const int grad_data_size =
        data_height * data_width * data_channels * batch_size;
    memset(static_cast<void*>(grad_data), 0, sizeof(T) * grad_data_size);
    memset(static_cast<void*>(grad_warp), 0, sizeof(T) * grad_warp_size);

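    // Per-batch strides: data is NHWC; warp packs an (x, y) pair per sample.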
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);

    auto update_grads_for_batches = [&](const int start, const int limit) {
      for (int batch_id = start; batch_id < limit; ++batch_id) {
        // Utility lambdas to access data and update gradient tensors.
        // They perform the relevant pointer arithmetic, abstracting the
        // low-level details away from the main loop over samples. Note
        // that data is stored in NHWC format.
        auto get_data_point = [&](const int x, const int y, const int chan) {
          const bool point_is_in_range =
              (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
          return point_is_in_range
                     ? data[batch_id * data_batch_stride +
                            data_channels * (y * data_width + x) + chan]
                     : zero;
        };

        auto update_grad_data = [&](const int x, const int y, const int chan,
                                    const T value) {
          const bool point_is_in_range =
              (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
          if (point_is_in_range) {
            grad_data[batch_id * data_batch_stride +
                      data_channels * (y * data_width + x) + chan] += value;
          }
        };

        auto update_grad_warp = [&](const int sample_id, const int channel,
                                    const T value) {
          grad_warp[batch_id * warp_batch_stride + sample_id * 2 + channel] +=
              value;
        };

        for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) {
          const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
          const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
          // The interpolation function whose gradient this kernel implements:
          // a) implicitly pads the input data with 0s (hence the unusual
          //    checks with {x,y} > -1);
          // b) returns 0 when sampling outside the (padded) image.
          // The effect is that the sampled signal smoothly goes to 0 outside
          // the original input domain, rather than presenting a jump
          // discontinuity at the image boundaries.
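          // (For example, a sample at x = -0.5 blends column 0 of the image
          // at half weight with implicit zeros, rather than being clipped.)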
          if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
              x < static_cast<T>(data_width) &&
              y < static_cast<T>(data_height)) {
            // Precompute floor (f) and ceil (c) values for x and y.
            const int fx = std::floor(static_cast<float>(x));
            const int fy = std::floor(static_cast<float>(y));
            const int cx = fx + 1;
            const int cy = fy + 1;
            const T dx = static_cast<T>(cx) - x;
            const T dy = static_cast<T>(cy) - y;

            for (int chan = 0; chan < data_channels; ++chan) {
              const T grad_output_value =
                  grad_output[batch_id * output_batch_stride +
                              sample_id * data_channels + chan];
              const T img_fxfy = get_data_point(fx, fy, chan);
              const T img_cxcy = get_data_point(cx, cy, chan);
              const T img_fxcy = get_data_point(fx, cy, chan);
              const T img_cxfy = get_data_point(cx, fy, chan);
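              // For reference, the forward value whose gradient is being
              // accumulated here is:
              //   output = img_fxfy * dx * dy
              //          + img_cxcy * (one - dx) * (one - dy)
              //          + img_fxcy * dx * (one - dy)
              //          + img_cxfy * (one - dx) * dy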

              // Update partial gradients wrt relevant warp field entries
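              // Since dx = cx - x and dy = cy - y, d(dx)/d(x) = -1 and
              // d(dy)/d(y) = -1; differentiating the forward value above
              // gives the two expressions below.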
              update_grad_warp(
                  sample_id, 0,
                  grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
                                       dy * (img_cxfy - img_fxfy)));

              update_grad_warp(
                  sample_id, 1,
                  grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
                                       dx * (img_fxcy - img_fxfy)));

              // Update partial gradients wrt sampled data
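              // The forward value is linear in each sampled pixel, so each
              // pixel's gradient is grad_output scaled by its bilinear
              // weight.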
              update_grad_data(fx, fy, chan, grad_output_value * dx * dy);
              update_grad_data(cx, cy, chan,
                               grad_output_value * (one - dx) * (one - dy));
              update_grad_data(fx, cy, chan,
                               grad_output_value * dx * (one - dy));
              update_grad_data(cx, fy, chan,
                               grad_output_value * (one - dx) * dy);
            }
          }
        }
      }
    };
    // Rough estimate of the work for each batch entry.
    // From third_party/tensorflow/core/util/work_sharder.cc we gather that
    // an estimate of the cost of each work unit is needed to correctly
    // shard the workload. thread_pool->ParallelFor assumes each cost unit
    // is 1ns, with a minimum cost of 10us per shard.
    // TODO(fviola): Check out if there is a better way of doing this.
    auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
    const int64 cost =
        static_cast<int64>(num_sampling_points) * data_channels * 1000;
    thread_pool->ParallelFor(batch_size, cost, update_grads_for_batches);
  }
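
Below is a minimal standalone sketch (an assumption-laden illustration, not
part of the kernel: plain C++, no TensorFlow headers, all names local to the
sketch) that checks the warp-gradient expression used above against a central
finite difference of the forward bilinear sample, on a 2x2 single-channel
image:

#include <cmath>
#include <cstdio>

// Pixel fetch with implicit zero padding, mirroring get_data_point above.
float PixelAt(const float* img, int height, int width, int x, int y) {
  return (x >= 0 && y >= 0 && x < width && y < height) ? img[y * width + x]
                                                       : 0.0f;
}

// Forward bilinear sample at (x, y), the interpolation whose gradient the
// kernel computes.
float Sample(const float* img, int height, int width, float x, float y) {
  const int fx = static_cast<int>(std::floor(x));
  const int fy = static_cast<int>(std::floor(y));
  const int cx = fx + 1;
  const int cy = fy + 1;
  const float dx = static_cast<float>(cx) - x;
  const float dy = static_cast<float>(cy) - y;
  return PixelAt(img, height, width, fx, fy) * dx * dy +
         PixelAt(img, height, width, cx, cy) * (1 - dx) * (1 - dy) +
         PixelAt(img, height, width, fx, cy) * dx * (1 - dy) +
         PixelAt(img, height, width, cx, fy) * (1 - dx) * dy;
}

int main() {
  const float img[] = {1, 2, 3, 4};  // 2x2 image, row-major
  const float x = 0.3f;
  const float y = 0.6f;  // here fx = fy = 0, cx = cy = 1, dy = 0.4
  const float dy = std::floor(y) + 1.0f - y;
  // Analytic d(output)/d(x): the expression passed to
  // update_grad_warp(sample_id, 0, ...), with grad_output_value = 1.
  const float analytic =
      (1 - dy) * (PixelAt(img, 2, 2, 1, 1) - PixelAt(img, 2, 2, 0, 1)) +
      dy * (PixelAt(img, 2, 2, 1, 0) - PixelAt(img, 2, 2, 0, 0));
  // Central finite difference of the forward sample.
  const float eps = 1e-3f;
  const float numeric =
      (Sample(img, 2, 2, x + eps, y) - Sample(img, 2, 2, x - eps, y)) /
      (2.0f * eps);
  std::printf("analytic %.4f, numeric %.4f\n", analytic, numeric);  // ~1.0 each
  return 0;
}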