in torchvision/csrc/ops/cpu/deform_conv2d_kernel.cpp [417:527]
void deformable_col2im_coord_kernel(
int n,
const scalar_t* col,
const scalar_t* im,
const scalar_t* offset,
const scalar_t* mask,
int channels,
int height,
int width,
int weight_h,
int weight_w,
int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int batch_sz,
int offset_channels,
int n_offset_grps,
int out_h,
int out_w,
bool use_mask,
scalar_t* grad_offset,
scalar_t* grad_mask) {
for (int index = 0; index != n; ++index) {
scalar_t grad_offset_val = 0;
scalar_t grad_mask_val = 0;
int w = index % out_w;
int h = (index / out_w) % out_h;
int w_w = (index / (out_w * out_h * 2)) % weight_w;
int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
int c = (index / (out_w * out_h)) % offset_channels;
int b = index / (out_w * out_h * offset_channels);
const int offset_grp = c / (2 * weight_h * weight_w);
const int col_step = weight_h * weight_w;
int c_per_offset_grp = channels / n_offset_grps;
auto col_ptr = col +
offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz * out_w *
out_h;
auto im_ptr = im +
(b * n_offset_grps + offset_grp) * c_per_offset_grp * height * width;
auto offset_ptr = offset +
(b * n_offset_grps + offset_grp) * 2 * weight_h * weight_w * out_h *
out_w;
auto mask_ptr = mask;
if (use_mask) {
mask_ptr += (b * n_offset_grps + offset_grp) * weight_h * weight_w *
out_h * out_w;
}
const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
const bool is_y_direction = offset_c % 2 == 0;
const int c_bound = c_per_offset_grp * weight_h * weight_w;
for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
int out_x = col_pos % out_w;
int out_y = (col_pos / out_w) % out_h;
int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
const int mask_idx = i * weight_w + j;
const int offset_h_idx =
(((2 * mask_idx) * out_h + out_y) * out_w + out_x);
const int offset_w_idx =
(((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);
const scalar_t offset_h = offset_ptr[offset_h_idx];
const scalar_t offset_w = offset_ptr[offset_w_idx];
scalar_t mask_value = 1;
if (use_mask) {
mask_value = mask_ptr[(mask_idx * out_h + out_y) * out_w + out_x];
}
scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
const scalar_t weight =
get_coordinate_weight(im_ptr, height, width, y, x, is_y_direction);
grad_offset_val += mask_value * weight * col_ptr[col_pos];
if (use_mask && is_y_direction) {
grad_mask_val += col_ptr[col_pos] *
bilinear_interpolate(im_ptr, height, width, y, x);
}
im_ptr += height * width;
}
grad_offset[index] = grad_offset_val;
if (use_mask && is_y_direction) {
const int idx =
((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w +
w_w) *
out_h +
h) *
out_w +
w;
grad_mask[idx] = grad_mask_val;
}
}
}