in tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc [209:327]
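// CPU implementation of the resampler gradient: given the incoming
// grad_output, accumulates the partial derivatives of the bilinear
// resampling op into grad_data (w.r.t. the input image) and grad_warp
// (w.r.t. the sampling coordinates).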
void operator()(OpKernelContext* ctx, const CPUDevice& d,
const T* __restrict__ data, const T* __restrict__ warp,
const T* __restrict__ grad_output, T* __restrict__ grad_data,
T* __restrict__ grad_warp, const int batch_size,
const int data_height, const int data_width,
const int data_channels, const int num_sampling_points) {
// Set gradients to 0, because the kernel incrementally updates the
// tensor entries by adding partial contributions.
const int resampler_output_size =
batch_size * num_sampling_points * data_channels;
const int grad_warp_size = resampler_output_size / data_channels * 2;
const int grad_data_size =
data_height * data_width * data_channels * batch_size;
memset(static_cast<void*>(grad_data), 0, sizeof(T) * grad_data_size);
memset(static_cast<void*>(grad_warp), 0, sizeof(T) * grad_warp_size);
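// Per-batch strides (in elements): data is laid out as NHWC, warp stores
// an (x, y) pair per sampling point, and the output/grad_output tensors
// store data_channels values per sampling point.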
const int data_batch_stride = data_height * data_width * data_channels;
const int warp_batch_stride = num_sampling_points * 2;
const int output_batch_stride = num_sampling_points * data_channels;
const T zero = static_cast<T>(0.0);
const T one = static_cast<T>(1.0);
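// Computes the gradient contributions for batch entries in [start, limit).
// Dispatched below via thread_pool->ParallelFor so that batch entries are
// processed in parallel shards.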
auto update_grads_for_batches = [&](const int start, const int limit) {
for (int batch_id = start; batch_id < limit; ++batch_id) {
// Utility lambdas to access data and update gradient tensors.
// The functions perform the relevant pointer arithmetic, abstracting
// away the low-level indexing details in the main loop over samples.
// Note that data is stored in NHWC format.
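// For reference, the flat NHWC index used below expands to
//   ((batch_id * data_height + y) * data_width + x) * data_channels + chan.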
auto get_data_point = [&](const int x, const int y, const int chan) {
const bool point_is_in_range =
(x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
return point_is_in_range
? data[batch_id * data_batch_stride +
data_channels * (y * data_width + x) + chan]
: zero;
};
auto update_grad_data = [&](const int x, const int y, const int chan,
const T value) {
const bool point_is_in_range =
(x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
if (point_is_in_range) {
grad_data[batch_id * data_batch_stride +
data_channels * (y * data_width + x) + chan] += value;
}
};
auto update_grad_warp = [&](const int sample_id, const int channel,
const T value) {
grad_warp[batch_id * warp_batch_stride + sample_id * 2 + channel] +=
value;
};
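// Loop over the sampling points; each warp entry holds an (x, y)
// sampling location in image coordinates.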
for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) {
const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
// The interpolation function whose gradient this function implements:
// a) implicitly pads the input data with 0s (hence the unusual checks
// with {x,y} > -1)
// b) returns 0 when sampling outside the (padded) image.
// The effect is that the sampled signal smoothly goes to 0 outside
// the original input domain, rather than presenting a jump
// discontinuity at the image boundaries.
if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
x < static_cast<T>(data_width) &&
y < static_cast<T>(data_height)) {
// Precompute floor (f) and ceil (c) values for x and y.
const int fx = std::floor(static_cast<float>(x));
const int fy = std::floor(static_cast<float>(y));
const int cx = fx + 1;
const int cy = fy + 1;
const T dx = static_cast<T>(cx) - x;
const T dy = static_cast<T>(cy) - y;
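// The gradients below are the partial derivatives of the bilinear
// interpolant
//   output = img_fxfy * dx * dy + img_cxcy * (1 - dx) * (1 - dy)
//          + img_fxcy * dx * (1 - dy) + img_cxfy * (1 - dx) * dy
// with dx = cx - x and dy = cy - y (so d(dx)/dx = d(dy)/dy = -1),
// taken with respect to x, y, and the four corner samples.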
for (int chan = 0; chan < data_channels; ++chan) {
const T grad_output_value =
grad_output[batch_id * output_batch_stride +
sample_id * data_channels + chan];
const T img_fxfy = get_data_point(fx, fy, chan);
const T img_cxcy = get_data_point(cx, cy, chan);
const T img_fxcy = get_data_point(fx, cy, chan);
const T img_cxfy = get_data_point(cx, fy, chan);
// Update partial gradients wrt relevant warp field entries
update_grad_warp(
sample_id, 0,
grad_output_value * ((one - dy) * (img_cxcy - img_fxcy) +
dy * (img_cxfy - img_fxfy)));
update_grad_warp(
sample_id, 1,
grad_output_value * ((one - dx) * (img_cxcy - img_cxfy) +
dx * (img_fxcy - img_fxfy)));
// Update partial gradients wrt sampled data
update_grad_data(fx, fy, chan, grad_output_value * dx * dy);
update_grad_data(cx, cy, chan,
grad_output_value * (one - dx) * (one - dy));
update_grad_data(fx, cy, chan,
grad_output_value * dx * (one - dy));
update_grad_data(cx, fy, chan,
grad_output_value * (one - dx) * dy);
}
}
}
}
};
// Rough estimate of the work for each batch entry.
// From third_party/tensorflow/core/util/work_sharder.cc we gather that an
// estimate of the cost of each work unit is needed to correctly shard the
// workload. thread_pool->ParallelFor assumes each cost unit is 1ns, with a
// minimum cost of 10us per shard.
// TODO(fviola): Check whether there is a better way of doing this.
auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
const int64 cost =
static_cast<int64>(num_sampling_points) * data_channels * 1000;
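// Illustrative numbers: with 1024 sampling points and 3 channels the
// estimated cost is 1024 * 3 * 1000 = 3,072,000 units, i.e. roughly 3 ms
// of work per batch entry under the 1ns-per-unit assumption.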
thread_pool->ParallelFor(batch_size, cost, update_grads_for_batches);
}