in tensorflow_addons/custom_ops/image/cc/kernels/resampler_ops.cc [42:122]
void operator()(OpKernelContext* ctx, const CPUDevice& d,
                const T* __restrict__ data, const T* __restrict__ warp,
                T* __restrict__ output, const int batch_size,
                const int data_height, const int data_width,
                const int data_channels, const int num_sampling_points) {
  const int warp_batch_stride = num_sampling_points * 2;
  const int data_batch_stride = data_height * data_width * data_channels;
  const int output_batch_stride = num_sampling_points * data_channels;
  const T zero = static_cast<T>(0.0);
  const T one = static_cast<T>(1.0);
  auto resample_batches = [&](const int start, const int limit) {
    for (int batch_id = start; batch_id < limit; ++batch_id) {
      // Utility lambdas to read input pixels and to write output values.
      // They take care of the relevant pointer arithmetic, abstracting away
      // the low-level indexing details from the main loop over samples.
      // Note that data is stored in NHWC format.
      auto set_output = [&](const int sample_id, const int channel,
                            const T value) {
        output[batch_id * output_batch_stride + sample_id * data_channels +
               channel] = value;
      };
      auto get_data_point = [&](const int x, const int y, const int chan) {
        const bool point_is_in_range =
            (x >= 0 && y >= 0 && x <= data_width - 1 && y <= data_height - 1);
        return point_is_in_range
                   ? data[batch_id * data_batch_stride +
                          data_channels * (y * data_width + x) + chan]
                   : zero;
      };
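      // For reference: with NHWC layout the flat index computed by
      // get_data_point for element (batch_id, y, x, chan) equals
      // ((batch_id * data_height + y) * data_width + x) * data_channels + chan.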
      for (int sample_id = 0; sample_id < num_sampling_points; ++sample_id) {
        const T x = warp[batch_id * warp_batch_stride + sample_id * 2];
        const T y = warp[batch_id * warp_batch_stride + sample_id * 2 + 1];
        // The interpolation function:
        // a) implicitly pads the input data with 0s (hence the unusual checks
        //    with {x,y} > -1)
        // b) returns 0 when sampling outside the (padded) image.
        // The effect is that the sampled signal smoothly goes to 0 outside
        // the original input domain, rather than presenting a jump
        // discontinuity at the image boundaries.
        if (x > static_cast<T>(-1.0) && y > static_cast<T>(-1.0) &&
            x < static_cast<T>(data_width) &&
            y < static_cast<T>(data_height)) {
          // Precompute floor (f) and ceil (c) values for x and y.
          const int fx = std::floor(static_cast<float>(x));
          const int fy = std::floor(static_cast<float>(y));
          const int cx = fx + 1;
          const int cy = fy + 1;
          const T dx = static_cast<T>(cx) - x;
          const T dy = static_cast<T>(cy) - y;
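          // Illustrative example (hypothetical values): for x = 2.3 the pair
          // (fx, cx) is (2, 3), so dx = 3 - 2.3 = 0.7 is the weight of the
          // floor column and (one - dx) the weight of the ceil column; the
          // four bilinear weights below always sum to one.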
          for (int chan = 0; chan < data_channels; ++chan) {
            const T img_fxfy = dx * dy * get_data_point(fx, fy, chan);
            const T img_cxcy =
                (one - dx) * (one - dy) * get_data_point(cx, cy, chan);
            const T img_fxcy = dx * (one - dy) * get_data_point(fx, cy, chan);
            const T img_cxfy = (one - dx) * dy * get_data_point(cx, fy, chan);
            set_output(sample_id, chan,
                       img_fxfy + img_cxcy + img_fxcy + img_cxfy);
          }
        } else {
          for (int chan = 0; chan < data_channels; ++chan) {
            set_output(sample_id, chan, zero);
          }
        }
      }
    }
  };
  // Rough estimate of work for each batch entry.
  // From third_party/tensorflow/core/util/work_sharder.cc we gather that an
  // estimate of the cost of each work unit is needed to correctly shard the
  // workload. thread_pool->ParallelFor assumes each cost unit is 1ns, with a
  // minimum cost of 10us per shard.
  const int64 cost =
      static_cast<int64>(num_sampling_points) * data_channels * 1000;
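  // For example (hypothetical sizes): num_sampling_points = 4096 and
  // data_channels = 3 give cost = 4096 * 3 * 1000 = 12,288,000, i.e. roughly
  // 12ms of estimated work per batch entry, comfortably above the 10us
  // minimum per shard.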
  auto thread_pool = ctx->device()->tensorflow_cpu_worker_threads()->workers;
  thread_pool->ParallelFor(batch_size, cost, resample_batches);
}
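
A minimal standalone sketch of the same bilinear sampling scheme, stripped of batching and sharding, may help when reading the kernel above. The function ResampleBilinear and the hard-coded 2x2 test image are illustrative only and not part of the TensorFlow Addons API; the implicit zero padding and the smooth fall-off to zero outside the image mirror the kernel's behaviour.

// resample_sketch.cc -- illustrative standalone example, not library code.
#include <cmath>
#include <cstdio>
#include <vector>

// Samples one (x, y) warp point from a single HWC image with implicit zero
// padding, mirroring one (batch_id, sample_id, chan) step of the kernel above.
float ResampleBilinear(const std::vector<float>& data, int height, int width,
                       int channels, float x, float y, int chan) {
  auto get_data_point = [&](int px, int py) -> float {
    const bool in_range = px >= 0 && py >= 0 && px < width && py < height;
    return in_range ? data[channels * (py * width + px) + chan] : 0.0f;
  };
  // Outside the zero-padded image the sample is defined to be 0.
  if (!(x > -1.0f && y > -1.0f && x < width && y < height)) return 0.0f;
  const int fx = static_cast<int>(std::floor(x));
  const int fy = static_cast<int>(std::floor(y));
  const int cx = fx + 1;
  const int cy = fy + 1;
  const float dx = cx - x;  // weight of the floor column
  const float dy = cy - y;  // weight of the floor row
  return dx * dy * get_data_point(fx, fy) +
         (1.0f - dx) * (1.0f - dy) * get_data_point(cx, cy) +
         dx * (1.0f - dy) * get_data_point(fx, cy) +
         (1.0f - dx) * dy * get_data_point(cx, fy);
}

int main() {
  // 2x2 single-channel image [[0, 1], [2, 3]] stored in HWC order.
  const std::vector<float> img = {0.0f, 1.0f, 2.0f, 3.0f};
  // Sampling the image centre averages the four pixels: 1.5.
  std::printf("centre  = %.3f\n", ResampleBilinear(img, 2, 2, 1, 0.5f, 0.5f, 0));
  // Half a pixel outside the left edge at y = 1: halfway between the pad
  // value 0 and the boundary pixel 2, i.e. 1.0 -- no jump at the border.
  std::printf("outside = %.3f\n", ResampleBilinear(img, 2, 2, 1, -0.5f, 1.0f, 0));
  return 0;
}

Note that, as in the kernel, the weight dx is measured from the ceil coordinate, so a point lying exactly on an integer coordinate reduces to a plain pixel lookup.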