in tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc [221:297]
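// Backward pass of the correlation cost op with respect to the second input:
// accumulates contributions from gradOutput into gradInput2 for a single
// batch element (item). rInput1 is the padded, channels-last copy of the
// first input; gradInput2 is assumed to be pre-allocated and zero filled.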
__global__ void Correlation_backward_input2(
    int item, float *__restrict__ gradInput2, int Cin, int Hin, int Win,
    const float *__restrict__ gradOutput, int Cout, int Hout, int Wout,
    const float *rInput1, int pad_size, int kernel_size, int max_displacement,
    int stride1, int stride2, bool is_NCHW) {
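  // Each block handles one (row h, column w, channel c) position of input2,
  // expressed in padded coordinates; the threads of the block cooperate over
  // the Cout output channels.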
  const int n = item;
  const int h = blockIdx.x * stride1 + pad_size;
  const int w = blockIdx.y * stride1 + pad_size;
  const int c = blockIdx.z;
  const int t0 = threadIdx.x;
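  // Patch radius, displacement grid size, padded input dimensions, and the
  // normalization factor (patch area times input channels) applied at the end.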
  const int kernel_rad = (kernel_size - 1) / 2;
  const int displacement_rad = max_displacement / stride2;
  const int displacement_size = 2 * displacement_rad + 1;
  const int pWin = Win + 2 * pad_size;
  const int pHin = Hin + 2 * pad_size;
  const float nelems = kernel_size * kernel_size * Cin;
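  // Shared storage for the warp-level reduction of per-thread partial sums.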
  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_sum_storage;
  float thread_accumulation = 0;
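  // Stride the threads over the Cout output channels; each channel tc maps to
  // one displacement pair (i2, j2) on the displacement_size x displacement_size
  // grid.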
  for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) {
    const int i2 = (tc % displacement_size - displacement_rad) * stride2;
    const int j2 = (tc / displacement_size - displacement_rad) * stride2;
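    // Range of output positions whose correlation window, shifted by (i2, j2),
    // touches this input2 location; skip this channel if the range is empty.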
    int Wmin = (w - kernel_rad - max_displacement - i2) / stride1;
    int Hmin = (h - kernel_rad - max_displacement - j2) / stride1;
    int Wmax = (w + kernel_rad - max_displacement - i2) / stride1;
    int Hmax = (h + kernel_rad - max_displacement - j2) / stride1;
    if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) {
      // assumes gradInput2 is pre-allocated and zero filled
      continue;
    }
    if (Wmin > Wmax || Hmin > Hmax) {
      // assumes gradInput2 is pre-allocated and zero filled
      continue;
    }
    Wmin = max(0, Wmin);
    Wmax = min(Wout - 1, Wmax);
    Hmin = max(0, Hmin);
    Hmax = min(Hout - 1, Hmax);
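    // Matching value from the padded, channels-last copy of input1 at the
    // position that pairs with (h, w) under displacement (i2, j2).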
    const int indx1 =
        n * (pHin * pWin * Cin) + (h - j2) * (pWin * Cin) + (w - i2) * Cin + c;
    const float val1 = ldg(rInput1 + indx1);
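    // Accumulate the chain-rule products over all contributing output
    // positions for this displacement.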
    for (int j = Hmin; j <= Hmax; ++j) {
      for (int i = Wmin; i <= Wmax; ++i) {
        const int tindx =
            n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i;
        thread_accumulation += ldg(gradOutput + tindx) * val1;
      }
    }
  }
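  // Reduce the per-thread accumulations; only lane 0 of the warp receives the
  // valid total from cub::WarpReduce::Sum.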
  __syncthreads();
  const float reduce_sum =
      WarpReduce(temp_sum_storage).Sum(thread_accumulation);
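  // Thread 0 writes the normalized gradient in either NCHW or NHWC layout.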
  if (t0 == 0) {
    if (is_NCHW) {
      const int indx2 = n * (Cin * Hin * Win) + c * (Hin * Win) +
                        (h - pad_size) * (Win) + (w - pad_size);
      gradInput2[indx2] = reduce_sum / nelems;
    } else {
      const int indx2 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) +
                        (w - pad_size) * Cin + c;
      gradInput2[indx2] = reduce_sum / nelems;
    }
  }
}