in tensorflow_addons/custom_ops/layers/cc/kernels/correlation_cost_op_gpu.cu.cc [375:456]
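// Backward pass of the correlation-cost op: computes the gradients w.r.t.
// both inputs. The inputs are first zero-padded (and transposed to NHWC if
// they arrive as NCHW); the backward kernels then accumulate the gradients
// one batch sample at a time.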
Status operator()(OpKernelContext *context, const Tensor &input_a_t,
                  const Tensor &input_b_t, const Tensor &topdiff_t,
                  Tensor *output_a_gradient_t, Tensor *output_b_gradient_t,
                  /* params */
                  int kernel_size, int max_displacement, int stride_1,
                  int stride_2, int pad, TensorFormat data_format) {
  // do not change: the CUDA kernels expect THREADS_PER_BLOCK == 32
  // (one warp per block)
  const int THREADS_PER_BLOCK = 32;

  const int32 N = GetTensorDim(input_a_t, data_format, 'N');
  const int32 iC = GetTensorDim(input_a_t, data_format, 'C');
  const int32 iH = GetTensorDim(input_a_t, data_format, 'H');
  const int32 iW = GetTensorDim(input_a_t, data_format, 'W');
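
  // Scratch tensors for the zero-padded copies of the inputs; these are
  // always laid out as NHWC regardless of data_format.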
  Tensor padded_a_t;
  Tensor padded_b_t;
  TensorShape padded_shape({N, iH + 2 * pad, iW + 2 * pad, iC});
  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<Dtype>::value,
                                            padded_shape, &padded_a_t));
  TF_RETURN_IF_ERROR(context->allocate_temp(DataTypeToEnum<Dtype>::value,
                                            padded_shape, &padded_b_t));

  dim3 blocks_grid(N, iH, iW);
  dim3 threads_block(THREADS_PER_BLOCK);
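  // Launch geometry for the padding kernels: one block per input pixel
  // (n, y, x); the 32 threads of each block presumably cooperate across
  // the iC channels.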

  // topdiff is NCHW
  const int32 oC = GetTensorDim(topdiff_t, FORMAT_NCHW, 'C');
  const int32 oH = GetTensorDim(topdiff_t, FORMAT_NCHW, 'H');
  const int32 oW = GetTensorDim(topdiff_t, FORMAT_NCHW, 'W');
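  // oC is the number of sampled displacements; oH x oW is the spatial size
  // of the correlation output.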

  // set everything to zero (we zero-pad); issue the fills on the op's GPU
  // stream so they are ordered before the kernel launches below
  const GPUDevice &d = context->eigen_gpu_device();
  cudaMemsetAsync(padded_a_t.flat<Dtype>().data(), 0,
                  padded_a_t.NumElements() * sizeof(Dtype), d.stream());
  cudaMemsetAsync(padded_b_t.flat<Dtype>().data(), 0,
                  padded_b_t.NumElements() * sizeof(Dtype), d.stream());
  cudaMemsetAsync(output_a_gradient_t->flat<Dtype>().data(), 0,
                  output_a_gradient_t->NumElements() * sizeof(Dtype),
                  d.stream());
  cudaMemsetAsync(output_b_gradient_t->flat<Dtype>().data(), 0,
                  output_b_gradient_t->NumElements() * sizeof(Dtype),
                  d.stream());

  const bool is_NCHW = (data_format == FORMAT_NCHW);
  if (is_NCHW) {
    pad_and_transpose<THREADS_PER_BLOCK>
        <<<blocks_grid, threads_block, 0, d.stream()>>>(
            input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(),
            iC, iH, iW, pad);
    pad_and_transpose<THREADS_PER_BLOCK>
        <<<blocks_grid, threads_block, 0, d.stream()>>>(
            input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(),
            iC, iH, iW, pad);
  } else {
    pad_and_no_transpose<THREADS_PER_BLOCK>
        <<<blocks_grid, threads_block, 0, d.stream()>>>(
            input_a_t.flat<Dtype>().data(), padded_a_t.flat<Dtype>().data(),
            iC, iH, iW, pad);
    pad_and_no_transpose<THREADS_PER_BLOCK>
        <<<blocks_grid, threads_block, 0, d.stream()>>>(
            input_b_t.flat<Dtype>().data(), padded_b_t.flat<Dtype>().data(),
            iC, iH, iW, pad);
  }
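
  // padded_a_t and padded_b_t now hold zero-padded NHWC copies of the inputs.
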
  dim3 totalBlocksCorr(iH, iW, iC);
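  // One block per input-gradient position (y, x, channel); the batch
  // dimension is walked by the host-side loops below.
  // Gradient w.r.t. input_a: reads topdiff and the padded input_b.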
  for (int n = 0; n < N; ++n) {
    Correlation_backward_input1<THREADS_PER_BLOCK>
        <<<totalBlocksCorr, threads_block, 0, d.stream()>>>(
            n, output_a_gradient_t->flat<Dtype>().data(), iC, iH, iW,
            topdiff_t.flat<Dtype>().data(), oC, oH, oW,
            padded_b_t.flat<Dtype>().data(), pad, kernel_size,
            max_displacement, stride_1, stride_2, is_NCHW);
  }
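
  // Gradient w.r.t. input_b: the same computation with the roles of the two
  // inputs swapped, reading from the padded input_a.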
  for (int n = 0; n < N; ++n) {
    Correlation_backward_input2<THREADS_PER_BLOCK>
        <<<totalBlocksCorr, threads_block, 0, d.stream()>>>(
            n, output_b_gradient_t->flat<Dtype>().data(), iC, iH, iW,
            topdiff_t.flat<Dtype>().data(), oC, oH, oW,
            padded_a_t.flat<Dtype>().data(), pad, kernel_size,
            max_displacement, stride_1, stride_2, is_NCHW);
  }

  return Status::OK();
}