__global__ void ResamplerGrad2DKernel()

in tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops_gpu.cu.cc [144:232]


__global__ void ResamplerGrad2DKernel(
    const T* __restrict__ data, const T* __restrict__ warp,
    const T* __restrict__ grad_output, T* __restrict__ grad_data,
    T* __restrict__ grad_warp, const int batch_size, const int data_height,
    const int data_width, const int data_channels,
    const int num_sampling_points) {
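  // One loop iteration per output element, i.e. per
  // (batch, sampling point, channel) triple.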
  const int resampler_output_size =
      batch_size * num_sampling_points * data_channels;
  GPU_1D_KERNEL_LOOP(index, resampler_output_size) {
    const int out_index = index;

    // Recover (batch_id, sample_id, chan) from the flat index, which is
    // laid out as
    //   index = batch_id * num_sampling_points * data_channels +
    //           sample_id * data_channels + chan,
    // with sample_id in [0, num_sampling_points).
    const int data_batch_stride = data_height * data_width * data_channels;
    const int warp_batch_stride = num_sampling_points * 2;
    const int output_batch_stride = num_sampling_points * data_channels;

    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;
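    // For example (hypothetical sizes): with num_sampling_points = 4 and
    // data_channels = 3, index 17 maps to batch_id = 1, sample_id = 1,
    // chan = 2.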

    // Get the (x, y) coordinates of the warp point at which to resample the
    // data.
    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
    const int warp_id_y = warp_id_x + 1;
    const T x = warp[warp_id_x];
    const T y = warp[warp_id_y];
    const T zero = static_cast<T>(0.0);
    const T one = static_cast<T>(1.0);

    // Get grad output
    const T grad_output_value = grad_output[out_index];
    // The interpolation function whose gradient this kernel implements:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    // with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      const int fx = std::floor(static_cast<float>(x));
      const int fy = std::floor(static_cast<float>(y));
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

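      // dx and dy are the bilinear weights of the floor column and floor row.
      // The four corner samples below read as zero when a corner falls
      // outside the image, which implements the implicit zero padding
      // described above.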
      const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero;

      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? GET_DATA_POINT(cx, cy)
                             : zero;

      const T img_fxcy =
          (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero;

      const T img_cxfy =
          (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero;

      // Update partial gradients wrt relevant warp field entries
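      // Atomics are required here because every channel of this sample
      // accumulates into the same two grad_warp entries.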
      GpuAtomicAdd(grad_warp + warp_id_x,
                   grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
                                        dy * (img_cxfy - img_fxfy)));
      GpuAtomicAdd(grad_warp + warp_id_y,
                   grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
                                        dx * (img_fxcy - img_fxfy)));

      // Update partial gradients wrt sampled data
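      // Each corner receives grad_output scaled by its bilinear weight.
      // UPDATE_GRAD_DATA_POINT handles the indexing into grad_data; the
      // write needs to be atomic (different sampling points can land on the
      // same input pixel), which that macro is expected to handle.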
      if (fx >= 0 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
      }
      if (cx <= data_width - 1 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(cx, cy,
                               grad_output_value * (one - dx) * (one - dy));
      }
      if (fx >= 0 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
      }
      if (cx <= data_width - 1 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
      }
    }
  }
}
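
For reference, the forward interpolation that this kernel differentiates can be sketched as the single-channel helper below. This is an illustrative assumption, not code from resampler_ops_gpu.cu.cc: GET_DATA_POINT and UPDATE_GRAD_DATA_POINT are macros defined earlier in that file that perform the corresponding read and (presumably atomic) accumulate on the full data / grad_data tensors, whereas the sketch works on one H x W channel plane.

#include <cmath>

// A minimal sketch, for illustration only, of the forward bilinear sample
// whose gradient ResamplerGrad2DKernel accumulates. `channel_img` is one
// H x W plane of a single channel; the helper name and signature are
// assumptions, not part of the library.
template <typename T>
T BilinearSampleSketch(const T* channel_img, int width, int height, T x, T y) {
  // Outside the one-pixel zero-padded domain the sampled signal is exactly 0.
  if (!(x > T(-1) && y > T(-1) && x < T(width) && y < T(height))) {
    return T(0);
  }
  const int fx = static_cast<int>(std::floor(static_cast<float>(x)));
  const int fy = static_cast<int>(std::floor(static_cast<float>(y)));
  const int cx = fx + 1;
  const int cy = fy + 1;
  const T dx = static_cast<T>(cx) - x;  // bilinear weight of the floor column
  const T dy = static_cast<T>(cy) - y;  // bilinear weight of the floor row

  // Corners that fall outside the image read as 0 (implicit zero padding).
  const T img_fxfy =
      (fx >= 0 && fy >= 0) ? channel_img[fy * width + fx] : T(0);
  const T img_cxcy =
      (cx < width && cy < height) ? channel_img[cy * width + cx] : T(0);
  const T img_fxcy =
      (fx >= 0 && cy < height) ? channel_img[cy * width + fx] : T(0);
  const T img_cxfy =
      (cx < width && fy >= 0) ? channel_img[fy * width + cx] : T(0);

  // The kernel's four grad_data updates are d(output)/d(corner value), and
  // its two grad_warp updates are d(output)/dx and d(output)/dy of this sum.
  return img_fxfy * dx * dy +
         img_cxcy * (T(1) - dx) * (T(1) - dy) +
         img_fxcy * dx * (T(1) - dy) +
         img_cxfy * (T(1) - dx) * dy;
}

Differentiating this expression with respect to x and y (using d(dx)/dx = d(dy)/dy = -1) reproduces the two grad_warp formulas in the kernel above.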