in tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops_gpu.cu.cc [144:232]
// Backward pass of the 2D bilinear resampler.
//
// Every loop index handles one element of the flattened output, laid out as
//   index = batch_id * num_sampling_points * data_channels +
//           sample_id * data_channels + chan.
// For that element the kernel reads the sampling coordinate (x, y) from
// `warp` (two consecutive values per sampling point), reads the incoming
// gradient from `grad_output`, and atomically accumulates:
//   * into `grad_warp`  - the partial derivatives w.r.t. x and y;
//   * into `grad_data`  - the bilinear weights of the (up to) four
//                         neighbouring pixels, scaled by the incoming gradient.
//
// Both outputs are only ever added to (GpuAtomicAdd), never assigned.
// NOTE(review): this presumably requires the caller to zero `grad_data` and
// `grad_warp` before launch - confirm at the launch site.
//
// `T` and the GET_DATA_POINT / UPDATE_GRAD_DATA_POINT macros are declared
// outside this block. The macros are expected to pick up locals of this
// kernel by name (e.g. batch_id, chan, data_batch_stride), so those names
// must not be changed.
__global__ void ResamplerGrad2DKernel(
    const T* __restrict__ data, const T* __restrict__ warp,
    const T* __restrict__ grad_output, T* __restrict__ grad_data,
    T* __restrict__ grad_warp, const int batch_size, const int data_height,
    const int data_width, const int data_channels,
    const int num_sampling_points) {
  const int resampler_output_size =
      batch_size * num_sampling_points * data_channels;

  // Loop-invariant strides and constants, computed once per thread.
  const int data_batch_stride = data_height * data_width * data_channels;
  const int warp_batch_stride = num_sampling_points * 2;
  const int output_batch_stride = num_sampling_points * data_channels;
  const T zero = static_cast<T>(0.0);
  const T one = static_cast<T>(1.0);

  GPU_1D_KERNEL_LOOP(index, resampler_output_size) {
    const int out_index = index;

    // Decompose the flat index into (batch_id, sample_id, chan), with
    // sample_id in [0, num_sampling_points).
    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Get coords of 2D point where data will be resampled.
    const int warp_id_x = batch_id * warp_batch_stride + sample_id * 2;
    const int warp_id_y = warp_id_x + 1;
    const T x = warp[warp_id_x];
    const T y = warp[warp_id_y];

    // Get grad output.
    const T grad_output_value = grad_output[out_index];

    // The interpolation function whose gradient this kernel implements:
    // a) implicitly pads the input data with 0s (hence the unusual checks
    //    with {x,y} > -1)
    // b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside
    // the original input domain, rather than presenting a jump
    // discontinuity at the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Precompute floor (f) and ceil (c) values for x and y.
      //
      // The floor is taken in float. When T is wider than float, a
      // coordinate just below an integer (e.g. data_width - 1e-8) can round
      // UP to that integer during the narrowing cast, which would yield
      // fx == data_width and let the `fx >= 0` guards below index one
      // column/row past the image. Stepping back by one whenever the
      // float-derived floor exceeds the original coordinate restores
      // fx <= x and fy <= y. For values exactly representable in float the
      // correction never triggers.
      const int fx_from_float = std::floor(static_cast<float>(x));
      const int fy_from_float = std::floor(static_cast<float>(y));
      const int fx = (static_cast<T>(fx_from_float) > x) ? fx_from_float - 1
                                                         : fx_from_float;
      const int fy = (static_cast<T>(fy_from_float) > y) ? fy_from_float - 1
                                                         : fy_from_float;
      const int cx = fx + 1;
      const int cy = fy + 1;

      // Distances from the sample to the "ceil" neighbour along each axis.
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

      // Fetch the four neighbours; anything outside the image counts as 0.
      const T img_fxfy = (fx >= 0 && fy >= 0) ? GET_DATA_POINT(fx, fy) : zero;
      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? GET_DATA_POINT(cx, cy)
                             : zero;
      const T img_fxcy =
          (fx >= 0 && cy <= data_height - 1) ? GET_DATA_POINT(fx, cy) : zero;
      const T img_cxfy =
          (cx <= data_width - 1 && fy >= 0) ? GET_DATA_POINT(cx, fy) : zero;

      // Update partial gradients wrt relevant warp field entries.
      GpuAtomicAdd(grad_warp + warp_id_x,
                   grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
                                        dy * (img_cxfy - img_fxfy)));
      GpuAtomicAdd(grad_warp + warp_id_y,
                   grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
                                        dx * (img_fxcy - img_fxfy)));

      // Update partial gradients wrt sampled data. Each in-bounds neighbour
      // receives its bilinear weight times the incoming gradient.
      if (fx >= 0 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(fx, fy, grad_output_value * dx * dy);
      }
      if (cx <= data_width - 1 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(cx, cy,
                               grad_output_value * (one - dx) * (one - dy));
      }
      if (fx >= 0 && cy <= data_height - 1) {
        UPDATE_GRAD_DATA_POINT(fx, cy, grad_output_value * dx * (one - dy));
      }
      if (cx <= data_width - 1 && fy >= 0) {
        UPDATE_GRAD_DATA_POINT(cx, fy, grad_output_value * (one - dx) * dy);
      }
    }
  }
}