in tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops_gpu.cu.cc [41:107]
// Bilinearly resamples `data` at the 2D points listed in `warp` and writes one
// value per (batch, sampling point, channel) into `output`.
//
// Indexing, as established by the code below:
//   output: index = batch_id * num_sampling_points * data_channels +
//                   sample_id * data_channels + chan
//   warp:   two consecutive values per sampling point, x first and then y:
//           warp[batch_id * num_sampling_points * 2 + sample_id * 2 + {0, 1}]
//   data:   read only through GET_DATA_POINT(x, y), a macro defined outside
//           this block. NOTE(review): it presumably picks up `data`,
//           `batch_id`, `data_batch_stride`, `data_channels`, `data_width` and
//           `chan` from the enclosing scope, so those local names are kept
//           unchanged here -- confirm against the macro definition.
//
// Sampling rule: the image is treated as if padded with a one-pixel border of
// zeros. Points with x in (-1, data_width) and y in (-1, data_height) are
// interpolated from their four neighbours (neighbours outside the image
// contribute 0); every other point, including NaN coordinates, produces 0.
//
// All index arithmetic is done in `int`.
// NOTE(review): this assumes each tensor has fewer than 2^31 elements.
__global__ void Resampler2DKernel(const T* __restrict__ data,
                                  const T* __restrict__ warp,
                                  T* __restrict__ output, const int batch_size,
                                  const int data_height, const int data_width,
                                  const int data_channels,
                                  const int num_sampling_points) {
  const int output_data_size = batch_size * num_sampling_points * data_channels;

  // Loop-invariant strides and constants, computed once per thread.
  const int data_batch_stride = data_height * data_width * data_channels;
  const int warp_batch_stride = num_sampling_points * 2;
  const int output_batch_stride = num_sampling_points * data_channels;
  const T zero = static_cast<T>(0.0);
  const T one = static_cast<T>(1.0);

  GPU_1D_KERNEL_LOOP(index, output_data_size) {
    const int out_index = index;

    // Decompose the flat output index into (batch_id, sample_id, chan), with
    // sample_id in [0, num_sampling_points).
    const int batch_id = index / output_batch_stride;
    const int index_in_batch = index % output_batch_stride;
    const int chan = index_in_batch % data_channels;
    const int sample_id = index_in_batch / data_channels;

    // Coordinates of the 2D point where data will be resampled.
    const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
    const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];

    // The interpolation function:
    //   a) implicitly pads the input data with 0s (hence the unusual checks
    //      with {x,y} > -1)
    //   b) returns 0 when sampling outside the (padded) image.
    // The effect is that the sampled signal smoothly goes to 0 outside the
    // original input domain, rather than presenting a jump discontinuity at
    // the image boundaries.
    if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
        x < static_cast<T>(data_width) && y < static_cast<T>(data_height)) {
      // Floor (f) of x and y, computed in float.
      int fx = static_cast<int>(std::floor(static_cast<float>(x)));
      int fy = static_cast<int>(std::floor(static_cast<float>(y)));
      // The conversion to float is exact for half and float, but when T is
      // wider (e.g. double) a coordinate just below an integer can round up
      // to that integer. fx could then equal data_width (or fy data_height),
      // which passes the `>= 0` guards below and reads the wrong pixel or
      // past the end of the image. Step back down whenever the candidate
      // floor is greater than the coordinate itself.
      if (static_cast<T>(fx) > x) --fx;
      if (static_cast<T>(fy) > y) --fy;

      // Ceil (c) neighbours and the weights of the floor-side samples.
      const int cx = fx + 1;
      const int cy = fy + 1;
      const T dx = static_cast<T>(cx) - x;
      const T dy = static_cast<T>(cy) - y;

      // Each corner contributes only when it lies inside the image; corners
      // that fall in the zero padding contribute nothing.
      const T img_fxfy =
          (fx >= 0 && fy >= 0) ? dx * dy * GET_DATA_POINT(fx, fy) : zero;
      const T img_cxcy = (cx <= data_width - 1 && cy <= data_height - 1)
                             ? (one - dx) * (one - dy) * GET_DATA_POINT(cx, cy)
                             : zero;
      const T img_fxcy = (fx >= 0 && cy <= data_height - 1)
                             ? dx * (one - dy) * GET_DATA_POINT(fx, cy)
                             : zero;
      const T img_cxfy = (cx <= data_width - 1 && fy >= 0)
                             ? (one - dx) * dy * GET_DATA_POINT(cx, fy)
                             : zero;
      output[out_index] = img_fxfy + img_cxcy + img_fxcy + img_cxfy;
    } else {
      output[out_index] = zero;
    }
  }
}