in include/tvm/topi/nn/pooling.h [49:211]
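/*!
 * \brief Common implementation of the gradient of 2-D pooling (max and average)
 *        over the height and width axes of the input.
 *
 * \param out_grad Gradient w.r.t. the pooling output
 * \param x The original pooling input
 * \param kernel_size Vector of two ints: {kernel_height, kernel_width}
 * \param stride_size Vector of two ints: {stride_height, stride_width}
 * \param padding_size Vector of four ints: {pad_top, pad_left, pad_bottom, pad_right}
 * \param pool_type The type of pooling operator (kMaxPool or kAvgPool)
 * \param ceil_mode Whether to use ceil when computing the output size
 * \param height_axis Index of the height dimension in x's shape
 * \param width_axis Index of the width dimension in x's shape
 * \param count_include_pad Whether padded elements are counted when averaging (kAvgPool only)
 *
 * \return The gradient w.r.t. the pooling input x, with the same shape as x
 */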
inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
const Array<PrimExpr>& padding_size, PoolType pool_type,
bool ceil_mode, const size_t height_axis, const size_t width_axis,
bool count_include_pad) {
ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must be >= 2-D (H, W)";
ICHECK(x->shape.size() >= 2) << "Pooling input must be >= 2-D (H, W)";
ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";
auto kernel_height = kernel_size[0];
auto kernel_width = kernel_size[1];
auto stride_height = stride_size[0];
auto stride_width = stride_size[1];
auto height = x->shape[height_axis];
auto width = x->shape[width_axis];
auto pad_top = padding_size[0];
auto pad_left = padding_size[1];
auto pad_bottom = padding_size[2];
auto pad_right = padding_size[3];
if (ceil_mode) {
// Additional padding to ensure we do ceil instead of floor when
// dividing by stride.
pad_bottom += stride_height - 1;
pad_right += stride_width - 1;
}
Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
pad_before.Set(height_axis, pad_top);
pad_before.Set(width_axis, pad_left);
Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
pad_after.Set(height_axis, pad_bottom);
pad_after.Set(width_axis, pad_right);
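// Output extents follow the usual pooling formula:
// out = (in + pad_before + pad_after - kernel) / stride + 1, simplified symbolically.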
arith::Analyzer analyzer;
auto out_height =
analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
auto out_width =
analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);
auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");
Array<PrimExpr> data_shape = x->shape;
Array<PrimExpr> out_shape = data_shape;
out_shape.Set(height_axis, out_height);
out_shape.Set(width_axis, out_width);
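// Only insert an explicit pad stage when at least one padding amount is statically non-zero.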
const int64_t* padding_h0 = as_const_int(pad_top);
const int64_t* padding_w0 = as_const_int(pad_left);
const int64_t* padding_h1 = as_const_int(pad_bottom);
const int64_t* padding_w1 = as_const_int(pad_right);
const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));
if (pool_type == kMaxPool) {
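// Max-pool gradient: first record, for every output element, the flattened (raveled)
// index of the window element that produced the max; then each input element sums
// out_grad over all output positions whose recorded argmax equals its own raveled index.
// ravel_shape is the shape of the padded input used for that flattening.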
Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);
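// An input element can be covered by at most ceil(kernel / stride) pooling windows
// along each spatial axis; windowh/windoww enumerate those candidate windows.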
auto windowh =
tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
auto windoww =
tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
auto argmax = MakeArgmaxReducer();
auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;
auto mp_argmax = tvm::te::compute(
out_shape,
[&](const Array<Var>& inds) {
Array<PrimExpr> window_inds{inds.begin(), inds.end()};
window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
auto idx = detail::RavelIndex(window_inds, ravel_shape);
return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
},
"maxpool_grad_argmax", kCommReduceIdx);
auto mp_inds = mp_argmax[0];
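// Scatter step: for each input location, compute its raveled index in the padded
// input and accumulate out_grad from every candidate output position whose stored
// argmax matches that index.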
return tvm::te::compute(
data_shape,
[&](const Array<Var>& inds) {
Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
auto idx = detail::RavelIndex(pad_inds, ravel_shape);
Array<PrimExpr> out_idx{inds.begin(), inds.end()};
out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);
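// Smallest output row/column whose pooling window still covers this (padded) input
// position; candidates below these bounds are rejected in the predicate below.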
PrimExpr out_idx_lower_h = tir::Select(
pad_inds[height_axis] < kernel_height, make_const(pad_inds[height_axis].dtype(), 0),
(pad_inds[height_axis] - kernel_height) / stride_height + 1);
PrimExpr out_idx_lower_w = tir::Select(
pad_inds[width_axis] < kernel_width, make_const(pad_inds[width_axis].dtype(), 0),
(pad_inds[width_axis] - kernel_width) / stride_width + 1);
return tvm::sum(
tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
out_idx[width_axis] >= out_idx_lower_w),
mp_inds(out_idx) == idx),
out_grad(out_idx), make_const(x->dtype, 0)),
{windowh, windoww});
},
"T_pool_grad", "pool_grad_max");
} else if (pool_type == kAvgPool) {
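// Average-pool gradient: every input element receives out_grad / divide_factor from
// each output window that covers it, where divide_factor is the number of elements
// averaged in that window.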
auto windowh =
tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
auto windoww =
tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
return tvm::te::compute(
data_shape,
[&](const Array<Var>& inds) {
PrimExpr pad_h_idx = inds[height_axis] + pad_top;
PrimExpr pad_w_idx = inds[width_axis] + pad_left;
// output indices whose pooling windows cover the current input element (may be out of bounds)
Array<PrimExpr> out_idx{inds.begin(), inds.end()};
out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));
PrimExpr out_idx_lower_h =
tir::Select(pad_h_idx < kernel_height, make_const(pad_h_idx.dtype(), 0),
(pad_h_idx - kernel_height) / stride_height + 1);
PrimExpr out_idx_lower_w =
tir::Select(pad_w_idx < kernel_width, make_const(pad_w_idx.dtype(), 0),
(pad_w_idx - kernel_width) / stride_width + 1);
PrimExpr divide_factor; // number of pooled elements
if (count_include_pad) {
divide_factor = kernel_height * kernel_width;
} else {
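// Count only elements inside the unpadded input: clip the window extent to
// [0, height) x [0, width) before computing its area.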
PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;
PrimExpr h_end = min(h_start + kernel_height, height);
PrimExpr w_end = min(w_start + kernel_width, width);
h_start = max(h_start, make_const(h_start.dtype(), 0));
w_start = max(w_start, make_const(w_start.dtype(), 0));
divide_factor =
max((h_end - h_start) * (w_end - w_start), make_const(h_end.dtype(), 1));
}
return tvm::sum(
tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
out_idx[height_axis] < out_height),
tir::And(out_idx[width_axis] >= out_idx_lower_w,
out_idx[width_axis] < out_width)),
out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
{windowh, windoww});
},
"T_pool_grad", "pool_grad_avg");
} else {
LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
return Tensor();
}
}
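
// Usage sketch (illustrative only; the helper name and shapes below are assumptions,
// not part of this header): gradient of a 2x2, stride-2 max pool over an NCHW input,
// where height_axis = 2 and width_axis = 3.
inline Tensor pool_grad_impl_example_nchw() {
  Tensor x = tvm::te::placeholder({1, 3, 8, 8}, DataType::Float(32), "x");
  Tensor out_grad = tvm::te::placeholder({1, 3, 4, 4}, DataType::Float(32), "out_grad");
  // No padding, no ceil_mode; gradients flow only to the argmax position of each window.
  return pool_grad_impl(out_grad, x, {2, 2}, {2, 2}, {0, 0, 0, 0}, kMaxPool,
                        /*ceil_mode=*/false, /*height_axis=*/2, /*width_axis=*/3,
                        /*count_include_pad=*/false);
}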