in torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp [748:865]
// Accumulates the weight gradient of a deformable convolution from `grad_out`
// and the deformable im2col columns of `input`.
//
// The batch is walked in chunks of `n_parallel_imgs` images. For every chunk a
// reusable column buffer is filled by deformable_im2col, and then, per weight
// group g, the product
//     grad_out_chunk[g] (as a 2-D matrix) @ columns[g]^T
// is added into the g-th slice of the gradient.
//
// Parameters:
//   input           - input activations; reshaped locally to
//                     (batch_sz / n_parallel_imgs, n_parallel_imgs, C_in, H, W).
//   weight          - read only for its sizes (dims 0, 2, 3) and as the
//                     template passed to at::zeros_like for the result.
//   offset          - sampling offsets; reshaped locally to
//                     (batch_sz / n_parallel_imgs, n_parallel_imgs,
//                      n_offset_grps * 2 * kh * kw, out_h, out_w).
//   mask            - modulation mask; reshaped only when `use_mask` is true,
//                     but indexed per chunk either way, so dim 0 must have at
//                     least batch_sz / n_parallel_imgs entries.
//                     NOTE(review): presumably callers pass a placeholder
//                     tensor when `use_mask` is false; confirm.
//   grad_out        - gradient w.r.t. the convolution output; dims 2 and 3
//                     give the output spatial size.
//   stride_*, pad_*, dilation_*
//                   - convolution geometry, forwarded to deformable_im2col.
//   n_weight_grps   - number of weight groups; channel counts are assumed to
//                     be divisible by it.
//   n_offset_grps   - number of offset groups, forwarded to deformable_im2col.
//   n_parallel_imgs - requested chunk size, clamped to batch_sz. batch_sz is
//                     assumed to be divisible by the clamped value (otherwise
//                     the reshapes below throw) — TODO confirm at call sites.
//   use_mask        - whether `mask` takes part in im2col.
//
// Returns a tensor with the same sizes as `weight`; all zeros when the batch
// is empty.
at::Tensor backward_gradient_parameters(
    at::Tensor input,
    const at::Tensor& weight,
    at::Tensor offset,
    at::Tensor mask,
    const at::Tensor& grad_out,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    int dilation_h,
    int dilation_w,
    int n_weight_grps,
    int n_offset_grps,
    int n_parallel_imgs,
    bool use_mask) {
  // --- Problem geometry (read before the empty-batch shortcut) ---
  const int batch_sz = input.size(0);
  const int n_in_channels = input.size(1);
  const int in_h = input.size(2);
  const int in_w = input.size(3);
  n_parallel_imgs = std::min(batch_sz, n_parallel_imgs);

  const long n_out_channels = weight.size(0);
  const int kernel_h = weight.size(2);
  const int kernel_w = weight.size(3);

  const long out_h = grad_out.size(2);
  const long out_w = grad_out.size(3);

  at::Tensor grad_w = at::zeros_like(weight);
  if (batch_sz == 0) {
    // Nothing to accumulate; the gradient stays all zeros.
    return grad_w;
  }

  const int n_chunks = batch_sz / n_parallel_imgs;
  const long out_ch_per_grp = n_out_channels / n_weight_grps;

  // Regroup grad_out as
  //   (chunk, weight group, out channels per group, image in chunk, out_h, out_w)
  // so that grad_out_chunks[c][g].flatten(1) is a 2-D
  // (out_ch_per_grp) x (n_parallel_imgs * out_h * out_w) matrix, whose second
  // extent equals the last dimension of the column buffer below.
  at::Tensor grad_out_chunks =
      grad_out
          .reshape(
              {n_chunks,
               n_parallel_imgs,
               n_weight_grps,
               out_ch_per_grp,
               out_h,
               out_w})
          .permute({0, 2, 3, 1, 4, 5})
          .contiguous();

  // Split the batch dimension of the per-image inputs into chunks.
  input = input.reshape(
      {n_chunks, n_parallel_imgs, n_in_channels, in_h, in_w});
  offset = offset.reshape(
      {n_chunks,
       n_parallel_imgs,
       n_offset_grps * 2 * kernel_h * kernel_w,
       out_h,
       out_w});
  if (use_mask) {
    mask = mask.reshape(
        {n_chunks,
         n_parallel_imgs,
         n_offset_grps * kernel_h * kernel_w,
         out_h,
         out_w});
  }

  // Expose the weight groups as a leading dimension of the gradient.
  grad_w = grad_w.view(
      {n_weight_grps,
       grad_w.size(0) / n_weight_grps,
       grad_w.size(1),
       grad_w.size(2),
       grad_w.size(3)});

  // Column buffer, reused across chunks:
  // (group, C_in/group * kh * kw, n_parallel_imgs * out_h * out_w).
  at::Tensor cols = at::empty(
      {n_weight_grps,
       n_in_channels * kernel_w * kernel_h / n_weight_grps,
       n_parallel_imgs * out_h * out_w},
      input.options());

  for (int chunk = 0; chunk < n_chunks; ++chunk) {
    deformable_im2col(
        input[chunk],
        offset[chunk],
        mask[chunk],
        n_in_channels,
        in_h,
        in_w,
        kernel_h,
        kernel_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        out_h,
        out_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        cols);

    for (int grp = 0; grp < n_weight_grps; ++grp) {
      at::Tensor grad_w_grp = grad_w[grp];
      // grad_w_grp += grad_out_chunks[chunk][grp] @ cols[grp]^T, computed on a
      // 2-D flattening of the slice. flatten() may return a copy rather than a
      // view, so the result is explicitly copied back into the slice.
      at::Tensor updated = grad_w_grp.flatten(1).addmm_(
          grad_out_chunks[chunk][grp].flatten(1), cols[grp].transpose(1, 0));
      grad_w[grp] = updated.view_as(grad_w_grp);
    }
  }

  // Merge (group, out channels per group) back into weight's leading dim.
  return grad_w.view(
      {grad_w.size(0) * grad_w.size(1),
       grad_w.size(2),
       grad_w.size(3),
       grad_w.size(4)});
}