in torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp [587:746]
// Computes the gradients of the deformable convolution with respect to
// `input`, `offset` and `mask`, given the upstream gradient `grad_out`.
//
// The batch is processed in blocks of `n_parallel_imgs` images (clamped to the
// batch size). For every block the column buffer is zeroed and then filled,
// one weight group at a time, with weight[g]^T * grad_out[block][g]. The
// buffer is then passed to compute_grad_offset_and_mask() and
// compute_grad_input() (both defined elsewhere in this file) together with the
// matching block slices of grad_offset / grad_mask / grad_input.
//
// Returns (grad_input, grad_offset, grad_mask). All three start as zero
// tensors shaped like `input`, `offset` and `mask`; for an empty batch they
// are returned immediately without further work.
//
// NOTE(review): the reshapes below assume batch_sz is a multiple of
// n_parallel_imgs and that the channel counts are divisible by n_weight_grps.
// Neither is checked here -- presumably the caller guarantees it; confirm.
std::tuple<at::Tensor, at::Tensor, at::Tensor> backward_gradient_inputs(
    at::Tensor input,
    at::Tensor weight,
    at::Tensor offset,
    at::Tensor mask,
    at::Tensor grad_out,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    int dilation_h,
    int dilation_w,
    int n_weight_grps,
    int n_offset_grps,
    int n_parallel_imgs,
    bool use_mask) {
  const int batch_sz = input.size(0);
  const int n_in_channels = input.size(1);
  const int in_h = input.size(2);
  const int in_w = input.size(3);

  // Never process more images per block than the batch contains.
  n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);

  const long n_out_channels = weight.size(0);
  const int weight_h = weight.size(2);
  const int weight_w = weight.size(3);

  // Output spatial size, derived from input size, padding, dilation and
  // stride.
  const long out_h =
      (in_h + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
  const long out_w =
      (in_w + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;

  auto grad_input = at::zeros_like(input);
  auto grad_offset = at::zeros_like(offset);
  auto grad_mask = at::zeros_like(mask);

  // Nothing to compute for an empty batch; this also avoids dividing by a
  // zero n_parallel_imgs below.
  if (batch_sz == 0) {
    return std::make_tuple(grad_input, grad_offset, grad_mask);
  }

  // Loop-invariant quantities used by the reshapes and the block loop.
  const int n_blocks = batch_sz / n_parallel_imgs;
  const int offset_channels = n_offset_grps * 2 * weight_h * weight_w;
  const int mask_channels = n_offset_grps * weight_h * weight_w;

  auto columns = at::empty(
      {n_in_channels * weight_w * weight_h, n_parallel_imgs * out_h * out_w},
      input.options());

  // Separate into blocks
  grad_input = grad_input.reshape(
      {n_blocks, n_parallel_imgs, n_in_channels, in_h, in_w});
  input =
      input.reshape({n_blocks, n_parallel_imgs, n_in_channels, in_h, in_w});

  grad_offset = grad_offset.reshape(
      {n_blocks, n_parallel_imgs, offset_channels, out_h, out_w});
  offset = offset.reshape(
      {n_blocks, n_parallel_imgs, offset_channels, out_h, out_w});

  // Without a mask, `mask` and `grad_mask` keep their original shapes and are
  // only indexed by block number along their first dimension below.
  if (use_mask) {
    grad_mask = grad_mask.reshape(
        {n_blocks, n_parallel_imgs, mask_channels, out_h, out_w});
    mask =
        mask.reshape({n_blocks, n_parallel_imgs, mask_channels, out_h, out_w});
  }

  // Bring grad_out to [block, group, out_channels_per_group, img, h, w] so
  // that grad_out[elt][g] flattens to one matrix per weight group.
  grad_out = grad_out
                 .reshape(
                     {n_blocks,
                      n_parallel_imgs,
                      n_weight_grps,
                      n_out_channels / n_weight_grps,
                      out_h,
                      out_w})
                 .permute({0, 2, 3, 1, 4, 5});

  weight = weight.reshape(
      {n_weight_grps,
       weight.size(0) / n_weight_grps,
       weight.size(1),
       weight.size(2),
       weight.size(3)});

  columns = columns.view(
      {n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});

  for (int elt = 0; elt < n_blocks; elt++) {
    columns.zero_();
    // Separate into weight groups
    for (int g = 0; g < n_weight_grps; g++) {
      // columns[g] is a view into the column buffer, so the in-place addmm_
      // already updates the buffer. Assigning the result back to columns[g]
      // would only trigger an extra element-wise copy of the slice onto
      // itself.
      columns[g].addmm_(
          weight[g].flatten(1).transpose(0, 1), grad_out[elt][g].flatten(1));
    }

    compute_grad_offset_and_mask(
        columns,
        input[elt],
        offset[elt],
        mask[elt],
        n_in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        grad_offset[elt],
        grad_mask[elt]);

    compute_grad_input(
        columns,
        offset[elt],
        mask[elt],
        n_in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        grad_input[elt]);
  }

  // Merge the block dimension back into the batch dimension.
  grad_input = grad_input.view({batch_sz, n_in_channels, in_h, in_w});
  grad_offset = grad_offset.view({batch_sz, offset_channels, out_h, out_w});

  if (use_mask) {
    grad_mask = grad_mask.view({batch_sz, mask_channels, out_h, out_w});
  }

  return std::make_tuple(grad_input, grad_offset, grad_mask);
}