in torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp [867:1086]
at::Tensor deform_conv2d_forward_kernel(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t n_weight_grps,
int64_t n_offset_grps,
bool use_mask) {
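// Work on contiguous copies so the views and raw-pointer kernels below see
// densely packed memory.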
at::Tensor input_c = input.contiguous();
at::Tensor offset_c = offset.contiguous();
at::Tensor weight_c = weight.contiguous();
at::Tensor mask_c = mask.contiguous();
at::Tensor bias_c = bias.contiguous();
TORCH_CHECK(input_c.ndimension() == 4, "input must be a 4-dim tensor");
TORCH_CHECK(offset_c.ndimension() == 4, "offset must be a 4-dim tensor");
TORCH_CHECK(!use_mask || mask_c.ndimension() == 4, "mask must be a 4-dim tensor");
TORCH_CHECK(weight_c.ndimension() == 4, "weight must be a 4-dim tensor");
TORCH_CHECK(input_c.device().is_cpu(), "input must be a CPU tensor");
int batch_sz = input_c.size(0);
int n_in_channels = input_c.size(1);
int in_h = input_c.size(2);
int in_w = input_c.size(3);
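// Process the batch in blocks of up to kMaxParallelImgs images so the
// column (im2col) buffer allocated below stays bounded.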
int n_parallel_imgs =
get_greatest_divisor_below_bound(batch_sz, kMaxParallelImgs);
// Unpack weight shape
int out_channels = weight_c.size(0);
int weight_h = weight_c.size(2);
int weight_w = weight_c.size(3);
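// Effective (dilated) kernel extent and the standard convolution
// output-size formula.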
int ker_h = dilation_h * (weight_h - 1) + 1;
int ker_w = dilation_w * (weight_w - 1) + 1;
int out_h = ((in_h + 2 * pad_h - ker_h) / stride_h) + 1;
int out_w = ((in_w + 2 * pad_w - ker_w) / stride_w) + 1;
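// Validate hyper-parameters and tensor shapes.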
TORCH_CHECK(
weight_h > 0 && weight_w > 0,
"weight_h: ",
weight_h,
" weight_w: ",
weight_w);
TORCH_CHECK(
stride_h > 0 && stride_w > 0,
"stride_h: ",
stride_h,
" stride_w: ",
stride_w);
TORCH_CHECK(pad_h >= 0 && pad_w >= 0, "pad_h: ", pad_h, " pad_w: ", pad_w);
TORCH_CHECK(
dilation_h > 0 && dilation_w > 0,
"dilation_h: ",
dilation_h,
" dilation_w: ",
dilation_w);
TORCH_CHECK(
weight_c.size(1) * n_weight_grps == input_c.size(1),
"input channels must equal weight.size(1) * n_weight_grps");
TORCH_CHECK(
weight_c.size(0) % n_weight_grps == 0,
"output channels must be divisible by n_weight_grps");
TORCH_CHECK(
(offset_c.size(1) == n_offset_grps * 2 * weight_h * weight_w),
"offset.shape[1] is not valid: got: ",
offset_c.size(1),
" expected: ",
n_offset_grps * 2 * weight_h * weight_w);
TORCH_CHECK(
(!use_mask || mask_c.size(1) == n_offset_grps * weight_h * weight_w),
"mask.shape[1] is not valid: got: ",
mask_c.size(1),
" expected: ",
n_offset_grps * weight_h * weight_w);
TORCH_CHECK(
input_c.size(1) % n_offset_grps == 0,
"input channels must be divisible by n_offset_grps");
TORCH_CHECK(
(offset_c.size(0) == input_c.size(0)), "invalid batch size of offset");
TORCH_CHECK(
(offset_c.size(2) == out_h && offset_c.size(3) == out_w),
"offset output dims: (",
offset_c.size(2),
", ",
offset_c.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
(!use_mask || mask_c.size(0) == input_c.size(0)),
"invalid batch size of mask");
TORCH_CHECK(
(!use_mask || (mask_c.size(2) == out_h && mask_c.size(3) == out_w)),
"mask output dims: (",
mask_c.size(2),
", ",
mask_c.size(3),
") - ",
"computed output dims: (",
out_h,
", ",
out_w,
")");
TORCH_CHECK(
out_h > 0 && out_w > 0,
"Calculated output size too small - out_h: ",
out_h,
" out_w: ",
out_w);
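// Allocate the full output up front; an empty batch short-circuits here.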
auto out =
at::zeros({batch_sz, out_channels, out_h, out_w}, input_c.options());
if (batch_sz == 0) {
return out;
}
// Separate batches into blocks
out = out.view(
{batch_sz / n_parallel_imgs,
n_parallel_imgs,
out_channels,
out_h,
out_w});
input_c = input_c.view(
{batch_sz / n_parallel_imgs, n_parallel_imgs, n_in_channels, in_h, in_w});
offset_c = offset_c.view(
{batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * 2 * weight_h * weight_w,
out_h,
out_w});
if (use_mask) {
mask_c = mask_c.view(
{batch_sz / n_parallel_imgs,
n_parallel_imgs,
n_offset_grps * weight_h * weight_w,
out_h,
out_w});
}
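// Intermediate buffer with the parallel images folded into the height
// dimension; it is unfolded and transposed back into `out` after the GEMMs.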
at::Tensor out_buf = at::zeros(
{batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs * out_h,
out_w},
out.options());
// Separate channels into convolution groups
out_buf = out_buf.view(
{out_buf.size(0),
n_weight_grps,
out_buf.size(1) / n_weight_grps,
out_buf.size(2),
out_buf.size(3)});
weight_c = weight_c.view(
{n_weight_grps,
weight_c.size(0) / n_weight_grps,
weight_c.size(1),
weight_c.size(2),
weight_c.size(3)});
// Sample points and perform convolution
auto columns = at::zeros(
{n_in_channels * weight_h * weight_w, n_parallel_imgs * out_h * out_w},
input_c.options());
for (int b = 0; b < batch_sz / n_parallel_imgs; b++) {
deformable_im2col(
input_c[b],
offset_c[b],
mask_c[b],
n_in_channels,
in_h,
in_w,
weight_h,
weight_w,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
out_h,
out_w,
n_parallel_imgs,
n_offset_grps,
use_mask,
columns);
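// Split the column buffer by weight group for the grouped GEMMs below.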
columns = columns.view(
{n_weight_grps, columns.size(0) / n_weight_grps, columns.size(1)});
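// One GEMM per group: out_buf[b][g] (C_out/G x N) is accumulated with
// weight[g] (C_out/G x C_in*kh*kw/G) times columns[g] (C_in*kh*kw/G x N),
// where N = n_parallel_imgs * out_h * out_w.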
for (int g = 0; g < n_weight_grps; g++) {
out_buf[b][g] = out_buf[b][g]
.flatten(1)
.addmm_(weight_c[g].flatten(1), columns[g])
.view_as(out_buf[b][g]);
}
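// Restore the flat 2-D shape for the next block's im2col.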
columns =
columns.view({columns.size(0) * columns.size(1), columns.size(2)});
}
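// Unfold the parallel-image blocks and move them back into the batch dim.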
out_buf = out_buf.view(
{batch_sz / n_parallel_imgs,
out_channels,
n_parallel_imgs,
out_h,
out_w});
out_buf.transpose_(1, 2);
out.copy_(out_buf);
out = out.view({batch_sz, out_channels, out_h, out_w});
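// Broadcast-add the bias over batch and spatial dimensions.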
return out + bias_c.view({1, out_channels, 1, 1});
}
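// A minimal usage sketch, assuming this kernel is visible in the current
// translation unit. `smoke_test_deform_conv2d` is a hypothetical helper added
// for illustration only, not part of torchvision. With zero offsets and an
// all-ones mask, deformable convolution degenerates to a regular 3x3
// convolution, which makes the call easy to sanity-check.
at::Tensor smoke_test_deform_conv2d() {
// N=1 image, C_in=4 channels, 8x8 spatial extent.
at::Tensor input = at::randn({1, 4, 8, 8});
// C_out=6, one weight group, 3x3 kernel.
at::Tensor weight = at::randn({6, 4, 3, 3});
// offset carries 2 * kh * kw channels per offset group (one group here) and
// must match the computed output size: (8 + 2*1 - 3) / 1 + 1 = 8.
at::Tensor offset = at::zeros({1, 2 * 3 * 3, 8, 8});
at::Tensor mask = at::ones({1, 3 * 3, 8, 8});
at::Tensor bias = at::zeros({6});
return deform_conv2d_forward_kernel(
input, weight, offset, mask, bias,
/*stride_h=*/1, /*stride_w=*/1,
/*pad_h=*/1, /*pad_w=*/1,
/*dilation_h=*/1, /*dilation_w=*/1,
/*n_weight_grps=*/1, /*n_offset_grps=*/1,
/*use_mask=*/true);
}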