__global__ void Correlation_backward_input1()

in tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc [140:218]


__global__ void Correlation_backward_input1(
    int item, float *__restrict__ gradInput1, int Cin, int Hin, int Win,
    const float *__restrict__ gradOutput, int Cout, int Hout, int Wout,
    const float *__restrict__ rInput2, int pad_size, int kernel_size,
    int max_displacement, int stride1, int stride2, bool is_NCHW) {
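  // Gradient of the correlation cost with respect to the first input. Each
  // block handles one padded location (h, w) and one channel c of input 1 for
  // batch item n; the block's threads split the Cout displacement channels.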
  const int n = item;
  const int h = blockIdx.x * stride1 + pad_size;
  const int w = blockIdx.y * stride1 + pad_size;
  const int c = blockIdx.z;
  const int t0 = threadIdx.x;

  const int kernel_rad = (kernel_size - 1) / 2;
  const int displacement_rad = max_displacement / stride2;
  const int displacement_size = 2 * displacement_rad + 1;

  // Range of gradOutput rows/columns whose correlation window covered the
  // input location (h, w) in the forward pass.
  int Wmin = (w - kernel_rad - max_displacement) / stride1;
  int Hmin = (h - kernel_rad - max_displacement) / stride1;

  int Wmax = (w + kernel_rad - max_displacement) / stride1;
  int Hmax = (h + kernel_rad - max_displacement) / stride1;

  if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) {
    // assumes gradInput1 is pre-allocated and zero filled
    return;
  }

  if (Wmin > Wmax || Hmin > Hmax) {
    // assumes gradInput1 is pre-allocated and zero filled
    return;
  }

  Wmin = max(0, Wmin);
  Wmax = min(Wout - 1, Wmax);

  Hmin = max(0, Hmin);
  Hmax = min(Hout - 1, Hmax);

  const int pWin = Win + 2 * pad_size;
  const int pHin = Hin + 2 * pad_size;
  const float nelems = kernel_size * kernel_size * Cin;

  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_sum_storage;
  float thread_accumulation = 0;

  for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) {
    // Each output channel tc encodes one displacement (i2, j2) of the second
    // input; Cout is expected to equal displacement_size * displacement_size.
    int i2 = (tc % displacement_size - displacement_rad) * stride2;
    int j2 = (tc / displacement_size - displacement_rad) * stride2;

    // rInput2 holds the padded second input in NHWC layout; read the value at
    // the displaced location (h + j2, w + i2), channel c.
    const int indx2 =
        n * (pHin * pWin * Cin) + (h + j2) * (pWin * Cin) + (w + i2) * Cin + c;

    const float val2 = ldg(rInput2 + indx2);

    // Accumulate the contributions of all gradOutput elements (stored as
    // NCHW over Cout x Hout x Wout) within the receptive-field bounds.
    for (int j = Hmin; j <= Hmax; ++j) {
      for (int i = Wmin; i <= Wmax; ++i) {
        const int tindx =
            n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i;
        thread_accumulation += ldg(gradOutput + tindx) * val2;
      }
    }
  }
  __syncthreads();

  // THREADS_PER_BLOCK==32, hence there is only one warp per block
  const float reduce_sum =
      WarpReduce(temp_sum_storage).Sum(thread_accumulation);
  // Lane 0 writes the gradient, normalized by the correlation window size,
  // back at the unpadded location in either NCHW or NHWC layout.
  if (t0 == 0) {
    if (is_NCHW) {
      const int indx1 = n * (Cin * Hin * Win) + c * (Hin * Win) +
                        (h - pad_size) * Win + (w - pad_size);
      gradInput1[indx1] = reduce_sum / nelems;
    } else {
      const int indx1 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) +
                        (w - pad_size) * Cin + c;
      gradInput1[indx1] = reduce_sum / nelems;
    }
  }
}
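
For context, the host-side launch is not part of this excerpt. The sketch below
shows how the kernel is presumably invoked: one block per (row, column, channel)
of the first input, THREADS_PER_BLOCK threads per block (a single warp, as the
warp-reduce comment implies), and a host loop over the batch. The wrapper name,
the pointer names, and the cudaMemsetAsync zero-fill are illustrative
assumptions, not code taken from the surrounding file.

#ifndef THREADS_PER_BLOCK
#define THREADS_PER_BLOCK 32  // assumed value; defined elsewhere in the real file
#endif

// Hypothetical launcher; names and exact launch shape are assumptions.
void LaunchCorrelationBackwardInput1(
    int N, int Cin, int Hin, int Win, int Cout, int Hout, int Wout,
    float *grad_input_a_ptr, const float *topdiff_ptr,
    const float *padded_b_ptr, int pad_size, int kernel_size,
    int max_displacement, int stride1, int stride2, bool is_NCHW,
    cudaStream_t stream) {
  // gradInput1 must be zero-filled up front: the kernel returns early for
  // locations with an empty receptive field instead of writing zeros.
  cudaMemsetAsync(grad_input_a_ptr, 0,
                  sizeof(float) * N * Cin * Hin * Win, stream);

  // One block per (h, w, c) location of the first input; a single warp
  // cooperates over the Cout displacement channels and reduces the sum.
  dim3 blocks(Hin, Win, Cin);
  dim3 threads(THREADS_PER_BLOCK);

  for (int n = 0; n < N; ++n) {
    Correlation_backward_input1<<<blocks, threads, 0, stream>>>(
        n, grad_input_a_ptr, Cin, Hin, Win, topdiff_ptr, Cout, Hout, Wout,
        padded_b_ptr, pad_size, kernel_size, max_displacement, stride1,
        stride2, is_NCHW);
  }
}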