__global__ void Correlation_backward_input2()

in tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc [221:297]


__global__ void Correlation_backward_input2(
    int item, float *__restrict__ gradInput2, int Cin, int Hin, int Win,
    const float *__restrict__ gradOutput, int Cout, int Hout, int Wout,
    const float *rInput1, int pad_size, int kernel_size, int max_displacement,
    int stride1, int stride2, bool is_NCHW) {
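  // One thread block per (padded row h, padded column w, channel c) of the
  // second input; the block's threads stride over the Cout output channels.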
  const int n = item;
  const int h = blockIdx.x * stride1 + pad_size;
  const int w = blockIdx.y * stride1 + pad_size;
  const int c = blockIdx.z;
  const int t0 = threadIdx.x;

  const int kernel_rad = (kernel_size - 1) / 2;
  const int displacement_rad = max_displacement / stride2;
  const int displacement_size = 2 * displacement_rad + 1;

  const int pWin = Win + 2 * pad_size;
  const int pHin = Hin + 2 * pad_size;
  // Normalization factor: kernel area times input channel count.
  const float nelems = kernel_size * kernel_size * Cin;

  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_sum_storage;
  float thread_accumulation = 0;

  for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) {
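    // Decode output channel tc into its 2-D displacement (i2, j2), in pixels.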
    const int i2 = (tc % displacement_size - displacement_rad) * stride2;
    const int j2 = (tc / displacement_size - displacement_rad) * stride2;

    // Range of output positions whose kernel window, shifted by the
    // displacement (i2, j2), covers the padded position (h, w) of the
    // second input.
    int Wmin = (w - kernel_rad - max_displacement - i2) / stride1;
    int Hmin = (h - kernel_rad - max_displacement - j2) / stride1;

    int Wmax = (w + kernel_rad - max_displacement - i2) / stride1;
    int Hmax = (h + kernel_rad - max_displacement - j2) / stride1;

    if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) {
      // assumes gradInput2 is pre-allocated and zero filled
      continue;
    }

    if (Wmin > Wmax || Hmin > Hmax) {
      // assumes gradInput2 is pre-allocated and zero filled
      continue;
    }

    Wmin = max(0, Wmin);
    Wmax = min(Wout - 1, Wmax);

    Hmin = max(0, Hmin);
    Hmax = min(Hout - 1, Hmax);

    // Element of the (padded) first input that the forward pass pairs with
    // position (h, w, c) of the second input under displacement (i2, j2).
    const int indx1 =
        n * (pHin * pWin * Cin) + (h - j2) * (pWin * Cin) + (w - i2) * Cin + c;
    const float val1 = ldg(rInput1 + indx1);

    // Accumulate the chain-rule contribution from each covering output position.
    for (int j = Hmin; j <= Hmax; ++j) {
      for (int i = Wmin; i <= Wmax; ++i) {
        const int tindx =
            n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i;
        thread_accumulation += ldg(gradOutput + tindx) * val1;
      }
    }
  }
  __syncthreads();

  // Sum the per-thread partial results across the warp; only thread 0 of the
  // block reads a defined value from the warp-level reduction.
  const float reduce_sum =
      WarpReduce(temp_sum_storage).Sum(thread_accumulation);
  if (t0 == 0) {
    if (is_NCHW) {
      const int indx2 = n * (Cin * Hin * Win) + c * (Hin * Win) +
                        (h - pad_size) * (Win) + (w - pad_size);
      gradInput2[indx2] = reduce_sum / nelems;
    } else {
      const int indx2 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) +
                        (w - pad_size) * Cin + c;
      gradInput2[indx2] = reduce_sum / nelems;
    }
  }
}
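
The launch code for this kernel lives elsewhere in the same file and is not shown here. The sketch below is a hypothetical host-side driver inferred from the index arithmetic above, not the actual TensorFlow Addons launch: the function name launch_backward_input2, the pointer names, the (Hin, Win, Cin) grid, and the per-batch-item loop are assumptions. It further assumes THREADS_PER_BLOCK == 32 (a single warp, which the cub::WarpReduce above relies on), stride1 == 1 so the grid touches every gradient element, and that gradInput2 has been zero-filled beforehand, as the kernel's early continue paths expect.

// Hypothetical host-side driver (not from the original file): one kernel
// launch per batch item, one block per (row, column, channel) of the second
// input's gradient, THREADS_PER_BLOCK (assumed 32) threads per block.
void launch_backward_input2(float *grad_input2_dev,          // zero-filled, N x Cin x Hin x Win
                            const float *grad_output_dev,    // upstream gradient
                            const float *padded_input1_dev,  // rInput1, padded NHWC
                            int N, int Cin, int Hin, int Win,
                            int Cout, int Hout, int Wout,
                            int pad_size, int kernel_size, int max_displacement,
                            int stride1, int stride2, bool is_NCHW,
                            cudaStream_t stream) {
  const dim3 blocks(Hin, Win, Cin);       // assumes stride1 == 1
  const dim3 threads(THREADS_PER_BLOCK);  // a single warp for cub::WarpReduce
  for (int n = 0; n < N; ++n) {
    Correlation_backward_input2<<<blocks, threads, 0, stream>>>(
        n, grad_input2_dev, Cin, Hin, Win, grad_output_dev, Cout, Hout, Wout,
        padded_input1_dev, pad_size, kernel_size, max_displacement, stride1,
        stride2, is_NCHW);
  }
}

Looping over batch items on the host matches the kernel's explicit item parameter and keeps the per-block index arithmetic three-dimensional.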