in torchvision/csrc/ops/cpu/interpolate_aa_kernels.cpp [506:618]
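// Backward (gradient) kernel for antialiased upsampling. Only the 2d case
// (4D NCHW tensors) is implemented: for each output pixel it recomputes the
// separable antialiasing weights used by the forward pass and scatter-adds
// the incoming gradient into `grad_input`.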
template <typename scalar_t, typename scale_type, template <typename, typename> class F>
void cpu_upsample_genNd_backward_aa(
    const Tensor& grad_input_,
    const Tensor& grad_output_,
    bool align_corners,
    const scale_type& scales) {
  TORCH_CHECK(
      grad_input_.dtype() == grad_output_.dtype(),
      "expected dtype ",
      grad_output_.dtype(),
      " for `grad_input` but got dtype ",
      grad_input_.dtype());
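  // Work on contiguous buffers; if `grad_input_` is non-contiguous, the
  // accumulated result is copied back into it at the end of the function.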
  auto grad_output = grad_output_.contiguous();
  auto grad_input = grad_input_.contiguous();

  auto grad_output_data = grad_output.data_ptr<scalar_t>();
  auto grad_input_data = grad_input.data_ptr<scalar_t>();

  auto input_sizes = grad_input.sizes().vec();
  auto output_sizes = grad_output.sizes().vec();
  auto ndim = input_sizes.size();
  // treat nbatch and channels as one dimension
  int64_t channels = input_sizes[0] * input_sizes[1];
  int64_t input_depth = (ndim == 5) ? input_sizes[2] : 1;
  int64_t output_depth = (ndim == 5) ? output_sizes[2] : 1;
  int64_t input_height = (ndim >= 4) ? input_sizes[ndim - 2] : 1;
  int64_t output_height = (ndim >= 4) ? output_sizes[ndim - 2] : 1;
  int64_t input_width = input_sizes[ndim - 1];
  int64_t output_width = output_sizes[ndim - 1];

  int64_t output_slice_size = output_depth * output_height * output_width;
  int interp_size = F<int64_t, float>::interp_size;
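  // For each output pixel: recompute the antialiasing weights along each
  // axis, then scatter-add grad_output into the input window that produced
  // that pixel in the forward pass.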
  auto loop2d = [&](int64_t begin, int64_t end) {
    const scalar_t height_scale = area_pixel_compute_scale<scalar_t>(
        input_height, output_height, align_corners, scales[0]);
    const scalar_t width_scale = area_pixel_compute_scale<scalar_t>(
        input_width, output_width, align_corners, scales[1]);
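    // Raw pointer into grad_input at (collapsed channel c, row h, column w).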
    auto input_indexr = [=](int64_t c, int64_t h, int64_t w) {
      return grad_input_data + c * input_height * input_width +
          h * input_width + w;
    };
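    // When downscaling (scale >= 1) the filter support is widened by the
    // scale factor so that all contributing input pixels are covered; when
    // upscaling it stays at the filter's native half-width (interp_size / 2).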
    const scalar_t support_h = (height_scale >= 1.0)
        ? (interp_size * 0.5) * height_scale
        : interp_size * 0.5;
    const scalar_t support_w = (width_scale >= 1.0)
        ? (interp_size * 0.5) * width_scale
        : interp_size * 0.5;

    const int interp_height = (int)ceilf(support_h) * 2 + 1;
    const int interp_width = (int)ceilf(support_w) * 2 + 1;

    std::vector<scalar_t> wx(interp_width, 0.0);
    std::vector<scalar_t> wy(interp_height, 0.0);
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    int64_t xmin, ymin;
    int64_t xsize, ysize;

    auto filter_fn = F<int64_t, scalar_t>::_filter;
    for (int64_t oh = 0; oh < output_height; oh++) {
      F<int64_t, scalar_t>::_compute_weights_aa(
          oh,
          input_height,
          height_scale,
          support_h,
          wy.data(),
          interp_height,
          filter_fn,
          ymin,
          ysize);

      for (int64_t ow = 0; ow < output_width; ow++) {
        F<int64_t, scalar_t>::_compute_weights_aa(
            ow,
            input_width,
            width_scale,
            support_w,
            wx.data(),
            interp_width,
            filter_fn,
            xmin,
            xsize);
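        // The forward pass computed output[oh][ow] as a weighted sum over the
        // input window [ymin, ymin + ysize) x [xmin, xmin + xsize), so the
        // backward pass distributes the gradient with the same weights.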
        for (int64_t c = begin; c < end; c++) {
          scalar_t grad_output_value =
              grad_output_data[c * output_slice_size + oh * output_width + ow];

          for (int64_t y = 0; y < ysize; y++) {
            for (int64_t x = 0; x < xsize; x++) {
              *input_indexr(c, ymin + y, xmin + x) +=
                  wx[x] * wy[y] * grad_output_value;
            }
          }
        }
      }
    }
  };
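  // Parallelize over the collapsed batch*channel dimension; the grain size
  // shrinks as the per-channel output grows so each task gets a comparable
  // amount of work.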
  if (ndim == 4) {
    // upsample bilinear 2d
    at::parallel_for(
        0, channels, at::internal::GRAIN_SIZE / output_slice_size / 4, loop2d);
  } else {
    TORCH_CHECK(false, "Unsupported tensor ndim: ", ndim);
  }
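  // Write the result back if we accumulated into a temporary contiguous copy.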
  if (!grad_input_.is_contiguous()) {
    grad_input_.copy_(grad_input);
  }
}
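
// A minimal sketch of how this kernel might be dispatched over dtypes
// (assumption: the entry-point name `upsample_bilinear2d_backward_aa_cpu`,
// the `scale_t` alias, and the `HelperInterpLinear` filter helper below are
// illustrative stand-ins, not confirmed by this excerpt):
//
//   using scale_t = std::vector<c10::optional<double>>;
//
//   void upsample_bilinear2d_backward_aa_cpu(
//       const Tensor& grad_input,
//       const Tensor& grad_output,
//       bool align_corners,
//       c10::optional<double> scales_h,
//       c10::optional<double> scales_w) {
//     AT_DISPATCH_FLOATING_TYPES(
//         grad_output.scalar_type(), "upsample_bilinear2d_backward_aa_cpu", [&] {
//           cpu_upsample_genNd_backward_aa<scalar_t, scale_t, HelperInterpLinear>(
//               grad_input, grad_output, align_corners, {scales_h, scales_w});
//         });
//   }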