void operator()

in tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc [42:122]


  // Bilinearly resamples `data` at the (x, y) coordinates stored in `warp`
  // and writes the result to `output`. Batch entries are independent and are
  // sharded across the CPU worker thread pool obtained from `ctx`.
  //
  // Memory layout, as implied by the indexing below:
  //   data:   [batch_size, data_height, data_width, data_channels] (NHWC)
  //   warp:   [batch_size, num_sampling_points, 2], stored as (x, y) pairs
  //   output: [batch_size, num_sampling_points, data_channels]
  //
  // Sampling points outside the open range (-1, data_width) x
  // (-1, data_height) produce zeros for every channel. `d` is not used by
  // this implementation.
  //
  // All element offsets are computed in 64-bit arithmetic: the products of
  // the dimensions can exceed the range of `int` for large tensors, and a
  // 32-bit overflow there would index outside the buffers.
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  const T* __restrict__ data, const T* __restrict__ warp,
                  T* __restrict__ output, const int batch_size,
                  const int data_height, const int data_width,
                  const int data_channels, const int num_sampling_points) {
    const int64 warp_batch_stride =
        static_cast<int64>(num_sampling_points) * 2;
    const int64 data_batch_stride =
        static_cast<int64>(data_height) * data_width * data_channels;
    const int64 output_batch_stride =
        static_cast<int64>(num_sampling_points) * data_channels;
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);

    auto resample_batches = [&](const int start, const int limit) {
      for (int batch_id = start; batch_id < limit; ++batch_id) {
        // Base pointers of the current batch entry; the int * int64 products
        // are evaluated in 64 bits.
        const T* batch_data = data + batch_id * data_batch_stride;
        const T* batch_warp = warp + batch_id * warp_batch_stride;
        T* batch_output = output + batch_id * output_batch_stride;

        // Utility lambdas to access data points and set output values.
        // They take care of the pointer arithmetic so that the main loop
        // over samples stays free of low level details. Note that data is
        // stored in NHWC format.
        auto set_output = [&](const int sample_id, const int channel,
                              const T value) {
          batch_output[static_cast<int64>(sample_id) * data_channels +
                       channel] = value;
        };

        // Returns the requested pixel/channel value, or zero when (x, y)
        // lies outside the image (implicit zero padding).
        auto get_data_point = [&](const int x, const int y, const int chan) {
          const bool point_is_in_range =
              (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
          return point_is_in_range
                     ? batch_data[data_channels *
                                      (static_cast<int64>(y) * data_width + x) +
                                  chan]
                     : zero;
        };

        for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) {
          const int64 warp_offset = static_cast<int64>(sample_id) * 2;
          const T x = batch_warp[warp_offset];
          const T y = batch_warp[warp_offset + 1];
          // The interpolation function:
          // a) implicitly pads the input data with 0s (hence the unusual checks
          // with {x,y} > -1)
          // b) returns 0 when sampling outside the (padded) image.
          // The effect is that the sampled signal smoothly goes to 0 outside
          // the original input domain, rather than presenting a jump
          // discontinuity at the image boundaries.
          if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
              x < static_cast<T>(data_width) &&
              y < static_cast<T>(data_height)) {
            // Precompute floor (f) and ceil (c) values for x and y.
            // NOTE(review): the coordinates are converted to float before
            // flooring; for a T wider than float this may round the
            // coordinate first -- confirm this is acceptable.
            const int fx = std::floor(static_cast<float>(x));
            const int fy = std::floor(static_cast<float>(y));
            const int cx = fx + 1;
            const int cy = fy + 1;
            // dx/dy are the weights of the floor neighbours; (one - dx) and
            // (one - dy) are the weights of the ceil neighbours.
            const T dx = static_cast<T>(cx) - x;
            const T dy = static_cast<T>(cy) - y;

            for (int chan = 0; chan < data_channels; ++chan) {
              const T img_fxfy = dx * dy * get_data_point(fx, fy, chan);
              const T img_cxcy =
                  (one - dx) * (one - dy) * get_data_point(cx, cy, chan);
              const T img_fxcy = dx * (one - dy) * get_data_point(fx, cy, chan);
              const T img_cxfy = (one - dx) * dy * get_data_point(cx, fy, chan);
              set_output(sample_id, chan,
                         img_fxfy + img_cxcy + img_fxcy + img_cxfy);
            }
          } else {
            for (int chan = 0; chan < data_channels; ++chan) {
              set_output(sample_id, chan, zero);
            }
          }
        }
      }
    };
    // Rough estimate of work for each batch entry.
    // From third_party/tensorflow/core/util/work_sharder.cc we gather that an
    // estimate of the cost of each work unit is needed to correctly shard the
    // workload. thread_pool->ParallelFor assumes each cost unit is 1ns, the
    // minimum cost per shard being 10us.
    const int64 cost =
        static_cast<int64>(num_sampling_points) * data_channels * 1000;
    auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
    thread_pool->ParallelFor(batch_size, cost, resample_batches);
  }