at::Tensor backward_gradient_parameters()

in torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp [748:865]


// Computes the gradient of a deformable 2D convolution with respect to its
// `weight` tensor.
//
// The batch is processed in blocks of `n_parallel_imgs` images. For each
// block the deformable im2col buffer (`columns`) is rebuilt from `input`,
// `offset` and (when `use_mask`) `mask` via deformable_im2col, and then for
// every weight group g the block's contribution
//     grad_weight[g] += grad_out_block[g] (flattened) x columns[g]^T
// is accumulated with a single addmm_.
//
// Shapes implied by the reshapes / GEMM below:
//   input    : (batch, n_in_channels, in_h, in_w)
//   weight   : (n_out_channels, n_in_channels / n_weight_grps,
//               weight_h, weight_w)
//   offset   : batch * n_offset_grps * 2 * weight_h * weight_w * out_h * out_w
//              elements, batch-major
//   mask     : batch * n_offset_grps * weight_h * weight_w * out_h * out_w
//              elements, batch-major (only reshaped/used when `use_mask`;
//              otherwise only `mask[elt]` is taken)
//   grad_out : (batch, n_out_channels, out_h, out_w)
//
// `n_parallel_imgs` is a batching knob only. It is clamped to [1, batch] and
// lowered to the largest divisor of the batch size, because every per-block
// reshape requires the batch to split exactly. The accumulated result does
// not depend on the block size (up to floating-point summation order).
//
// Returns a tensor with the same shape as `weight` (all zeros for an empty
// batch).
at::Tensor backward_gradient_parameters(
    at::Tensor input,
    const at::Tensor& weight,
    at::Tensor offset,
    at::Tensor mask,
    const at::Tensor& grad_out,
    int stride_h,
    int stride_w,
    int pad_h,
    int pad_w,
    int dilation_h,
    int dilation_w,
    int n_weight_grps,
    int n_offset_grps,
    int n_parallel_imgs,
    bool use_mask) {
  int batch_sz = input.size(0);
  int n_in_channels = input.size(1);
  int in_h = input.size(2);
  int in_w = input.size(3);

  // int64_t rather than `long`: `long` is only 32 bits on LLP64 platforms
  // (e.g. Windows), which would truncate size products such as
  // n_parallel_imgs * out_h * out_w below.
  int64_t n_out_channels = weight.size(0);
  int weight_h = weight.size(2);
  int weight_w = weight.size(3);

  int64_t out_h = grad_out.size(2);
  int64_t out_w = grad_out.size(3);

  auto grad_weight = at::zeros_like(weight);
  if (batch_sz == 0) {
    return grad_weight;
  }

  // Guard the integer divisions below against a zero/negative block size and
  // step down to the largest divisor of batch_sz, so every reshape below
  // covers the whole batch exactly instead of failing on an element-count
  // mismatch.
  n_parallel_imgs = std::max(1, std::min(batch_sz, n_parallel_imgs));
  while (batch_sz % n_parallel_imgs != 0) {
    --n_parallel_imgs;
  }
  const int n_blocks = batch_sz / n_parallel_imgs;

  // (n_blocks, n_weight_grps, out_channels_per_grp, n_parallel_imgs, out_h,
  // out_w): grouping images of one block next to the spatial dims lets
  // grad_out_buf[elt][g].flatten(1) line up with the columns layout.
  at::Tensor grad_out_buf = grad_out
                                .reshape(
                                    {n_blocks,
                                     n_parallel_imgs,
                                     n_weight_grps,
                                     n_out_channels / n_weight_grps,
                                     out_h,
                                     out_w})
                                .permute({0, 2, 3, 1, 4, 5})
                                .contiguous();

  input = input.reshape(
      {n_blocks, n_parallel_imgs, n_in_channels, in_h, in_w});

  offset = offset.reshape(
      {n_blocks,
       n_parallel_imgs,
       n_offset_grps * 2 * weight_h * weight_w,
       out_h,
       out_w});

  if (use_mask) {
    mask = mask.reshape(
        {n_blocks,
         n_parallel_imgs,
         n_offset_grps * weight_h * weight_w,
         out_h,
         out_w});
  }

  // Split the output-channel dimension into (group, channels_per_group).
  grad_weight = grad_weight.view(
      {n_weight_grps,
       grad_weight.size(0) / n_weight_grps,
       grad_weight.size(1),
       grad_weight.size(2),
       grad_weight.size(3)});

  // im2col buffer, reused (overwritten) for every block.
  auto columns = at::empty(
      {n_weight_grps,
       n_in_channels * weight_w * weight_h / n_weight_grps,
       n_parallel_imgs * out_h * out_w},
      input.options());

  for (int elt = 0; elt < n_blocks; ++elt) {
    deformable_im2col(
        input[elt],
        offset[elt],
        mask[elt],
        n_in_channels,
        in_h,
        in_w,
        weight_h,
        weight_w,
        pad_h,
        pad_w,
        stride_h,
        stride_w,
        dilation_h,
        dilation_w,
        out_h,
        out_w,
        n_parallel_imgs,
        n_offset_grps,
        use_mask,
        columns);

    for (int g = 0; g < n_weight_grps; ++g) {
      // grad_weight[g] yields a temporary view. Assigning to that rvalue
      // copies (copy_) the accumulated result back into grad_weight, which is
      // required when flatten(1) could not return a view and addmm_
      // therefore updated a copy.
      grad_weight[g] =
          grad_weight[g]
              .flatten(1)
              .addmm_(
                  grad_out_buf[elt][g].flatten(1), columns[g].transpose(1, 0))
              .view_as(grad_weight[g]);
    }
  }

  // Merge (group, channels_per_group) back into weight's original shape.
  grad_weight = grad_weight.view(
      {grad_weight.size(0) * grad_weight.size(1),
       grad_weight.size(2),
       grad_weight.size(3),
       grad_weight.size(4)});
  return grad_weight;
}