in tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc [140:218]
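// Gradient of the correlation cost volume with respect to the first input.
// Each thread block computes one element of gradInput1: (blockIdx.x,
// blockIdx.y) select a stride1-spaced spatial position (in padded
// coordinates) and blockIdx.z the input channel. The block's threads split
// the Cout output channels (displacement pairs) between them and combine
// their partial sums with a CUB warp reduction.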
__global__ void Correlation_backward_input1(
    int item, float *__restrict__ gradInput1, int Cin, int Hin, int Win,
    const float *__restrict__ gradOutput, int Cout, int Hout, int Wout,
    const float *__restrict__ rInput2, int pad_size, int kernel_size,
    int max_displacement, int stride1, int stride2, bool is_NCHW) {
  const int n = item;
  const int h = blockIdx.x * stride1 + pad_size;
  const int w = blockIdx.y * stride1 + pad_size;
  const int c = blockIdx.z;
  const int t0 = threadIdx.x;

  const int kernel_rad = (kernel_size - 1) / 2;
  const int displacement_rad = max_displacement / stride2;
  const int displacement_size = 2 * displacement_rad + 1;
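  // Range of output-map positions whose kernel window on the first input
  // covers the pixel (h, w) handled by this block; positions outside this
  // range contribute no gradient here.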
  int Wmin = (w - kernel_rad - max_displacement) / stride1;
  int Hmin = (h - kernel_rad - max_displacement) / stride1;

  int Wmax = (w + kernel_rad - max_displacement) / stride1;
  int Hmax = (h + kernel_rad - max_displacement) / stride1;

  if (Wmax < 0 || Hmax < 0 || Wmin >= Wout || Hmin >= Hout) {
    // assumes gradInput1 is pre-allocated and zero filled
    return;
  }

  if (Wmin > Wmax || Hmin > Hmax) {
    // assumes gradInput1 is pre-allocated and zero filled
    return;
  }
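  // Clamp the contributing output range to valid output coordinates.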
  Wmin = max(0, Wmin);
  Wmax = min(Wout - 1, Wmax);

  Hmin = max(0, Hmin);
  Hmax = min(Hout - 1, Hmax);
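  // Padded input dimensions, and the normalization factor (the correlation
  // window size, kernel_size * kernel_size * Cin).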
  const int pWin = Win + 2 * pad_size;
  const int pHin = Hin + 2 * pad_size;
  const float nelems = kernel_size * kernel_size * Cin;
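  // Each thread accumulates a partial sum over its share of the output
  // channels; the block result is formed with a CUB warp reduction below.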
  typedef cub::WarpReduce<float> WarpReduce;
  __shared__ typename WarpReduce::TempStorage temp_sum_storage;
  float thread_accumulation = 0;
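  // The Cout output channels enumerate a displacement_size x displacement_size
  // grid of offsets of the second input; channel tc corresponds to the
  // displacement (i2, j2).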
  for (int tc = t0; tc < Cout; tc += THREADS_PER_BLOCK) {
    const int i2 = (tc % displacement_size - displacement_rad) * stride2;
    const int j2 = (tc / displacement_size - displacement_rad) * stride2;

    // rInput2 is the second input in padded NHWC layout; fetch its value at
    // the location displaced by (j2, i2) from (h, w), channel c.
    const int indx2 =
        n * (pHin * pWin * Cin) + (h + j2) * (pWin * Cin) + (w + i2) * Cin + c;
    const float val2 = ldg(rInput2 + indx2);

    // Accumulate the chain-rule contribution from every output position
    // whose window covers (h, w); gradOutput is indexed in NCHW order.
    for (int j = Hmin; j <= Hmax; ++j) {
      for (int i = Wmin; i <= Wmax; ++i) {
        const int tindx =
            n * (Cout * Hout * Wout) + tc * (Hout * Wout) + j * Wout + i;
        thread_accumulation += ldg(gradOutput + tindx) * val2;
      }
    }
  }
  __syncthreads();

  // THREADS_PER_BLOCK==32, hence there is only one warp per block
  const float reduce_sum =
      WarpReduce(temp_sum_storage).Sum(thread_accumulation);
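  // Only lane 0 holds the valid warp-reduced sum; it writes the gradient,
  // averaged by nelems, at the unpadded location in the requested layout.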
  if (t0 == 0) {
    if (is_NCHW) {
      const int indx1 = n * (Cin * Hin * Win) + c * (Hin * Win) +
                        (h - pad_size) * Win + (w - pad_size);
      gradInput1[indx1] = reduce_sum / nelems;
    } else {
      const int indx1 = n * (Cin * Hin * Win) + (h - pad_size) * (Win * Cin) +
                        (w - pad_size) * Cin + c;
      gradInput1[indx1] = reduce_sum / nelems;
    }
  }
}
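// A hypothetical launch sketch, for illustration only: the real launcher lives
// elsewhere in this file, and the buffer names (grad_input1, grad_output,
// r_input2), `stream`, and the stride1 == 1 assumption are placeholders, not
// taken from this excerpt. With stride1 == 1, one 32-thread block per
// gradInput1 element covers every spatial position and channel:
//
//   dim3 blocks(Hin, Win, Cin);    // blockIdx = (row, col, channel) of gradInput1
//   for (int n = 0; n < N; ++n) {  // one launch per batch item
//     Correlation_backward_input1<<<blocks, THREADS_PER_BLOCK, 0, stream>>>(
//         n, grad_input1, Cin, Hin, Win, grad_output, Cout, Hout, Wout,
//         r_input2, pad_size, kernel_size, max_displacement, /*stride1=*/1,
//         stride2, is_NCHW);
//   }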