inline Tensor pool_grad_impl()

in include/tvm/topi/nn/pooling.h [49:211]

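Computes the gradient of 2D pooling with respect to the input x, given out_grad, the gradient of the pooled output. Both max pooling and average pooling are handled; the two spatial axes are identified by height_axis and width_axis, and the remaining parameters mirror the forward pooling computation: kernel size, strides, padding, ceil-mode output rounding, and (for average pooling) whether padded elements count toward the averaging divisor.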

inline Tensor pool_grad_impl(const Tensor& out_grad, const Tensor& x,
                             const Array<PrimExpr>& kernel_size, const Array<PrimExpr>& stride_size,
                             const Array<PrimExpr>& padding_size, PoolType pool_type,
                             bool ceil_mode, const size_t height_axis, const size_t width_axis,
                             bool count_include_pad) {
  ICHECK(out_grad->shape.size() >= 2) << "Pooling grad output must >= 2-D (H, W)";
  ICHECK(x->shape.size() >= 2) << "Pooling input must >= 2-D (H, W)";
  ICHECK_EQ(kernel_size.size(), 2) << "Pooling kernel_size must have 2 elements";
  ICHECK_EQ(stride_size.size(), 2) << "Pooling stride_size must have 2 elements";
  ICHECK_EQ(padding_size.size(), 4) << "Pooling padding_size must have 4 elements";

  auto kernel_height = kernel_size[0];
  auto kernel_width = kernel_size[1];
  auto stride_height = stride_size[0];
  auto stride_width = stride_size[1];

  auto height = x->shape[height_axis];
  auto width = x->shape[width_axis];

  auto pad_top = padding_size[0];
  auto pad_left = padding_size[1];
  auto pad_bottom = padding_size[2];
  auto pad_right = padding_size[3];

  if (ceil_mode) {
    // Additional padding to ensure we do ceil instead of floor when
    // dividing by stride.
    pad_bottom += stride_height - 1;
    pad_right += stride_width - 1;
  }

  Array<PrimExpr> pad_before(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_before.Set(height_axis, pad_top);
  pad_before.Set(width_axis, pad_left);

  Array<PrimExpr> pad_after(std::vector<PrimExpr>(x->shape.size(), 0));
  pad_after.Set(height_axis, pad_bottom);
  pad_after.Set(width_axis, pad_right);
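  // Output spatial extents, simplified from the padded input size, kernel, and stride.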
  arith::Analyzer analyzer;
  auto out_height =
      analyzer.Simplify((height - kernel_height + pad_top + pad_bottom) / stride_height + 1);
  auto out_width =
      analyzer.Simplify((width - kernel_width + pad_left + pad_right) / stride_width + 1);

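  // Reduction axes that iterate over one pooling window.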
  auto dheight = tvm::te::reduce_axis(Range(0, kernel_height), "dh");
  auto dwidth = tvm::te::reduce_axis(Range(0, kernel_width), "dw");

  Array<PrimExpr> data_shape = x->shape;
  Array<PrimExpr> out_shape = data_shape;
  out_shape.Set(height_axis, out_height);
  out_shape.Set(width_axis, out_width);

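  // Padding is only materialized below when at least one padding amount is a
  // known non-zero constant.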
  const int64_t* padding_h0 = as_const_int(pad_top);
  const int64_t* padding_w0 = as_const_int(pad_left);
  const int64_t* padding_h1 = as_const_int(pad_bottom);
  const int64_t* padding_w1 = as_const_int(pad_right);
  const bool do_pad = ((padding_h0 && *padding_h0) || (padding_w0 && *padding_w0)) ||
                      ((padding_h1 && *padding_h1) || (padding_w1 && *padding_w1));

  if (pool_type == kMaxPool) {
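    // Max pooling gradient: first record, for every output element, the flattened
    // (raveled) index of the window's argmax over the padded input, then scatter
    // out_grad back to the input positions whose index matches that argmax.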
    Array<PrimExpr> ravel_shape{data_shape.begin(), data_shape.end()};
    ravel_shape.Set(height_axis, ravel_shape[height_axis] + pad_top + pad_bottom);
    ravel_shape.Set(width_axis, ravel_shape[width_axis] + pad_left + pad_right);

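    // At most ceil(kernel / stride) pooling windows along each axis can cover a
    // given input element.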
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");

    auto argmax = MakeArgmaxReducer();
    auto pad_x = do_pad ? pad(x, pad_before, pad_after, tvm::min_value(x->dtype), "pad_temp") : x;

    auto mp_argmax = tvm::te::compute(
        out_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> window_inds{inds.begin(), inds.end()};
          window_inds.Set(height_axis, inds[height_axis] * stride_height + dheight);
          window_inds.Set(width_axis, inds[width_axis] * stride_width + dwidth);
          auto idx = detail::RavelIndex(window_inds, ravel_shape);
          return argmax({idx, pad_x(window_inds)}, {dheight, dwidth}, nullptr);
        },
        "maxpool_grad_argmax", kCommReduceIdx);

    auto mp_inds = mp_argmax[0];

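    // For each input element, sum out_grad over every candidate window that could
    // cover it and whose recorded argmax equals this element's flattened padded index.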
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          Array<PrimExpr> pad_inds{inds.begin(), inds.end()};
          pad_inds.Set(height_axis, pad_inds[height_axis] + pad_top);
          pad_inds.Set(width_axis, pad_inds[width_axis] + pad_left);
          auto idx = detail::RavelIndex(pad_inds, ravel_shape);

          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (inds[height_axis] + pad_top) / stride_height - windowh);
          out_idx.Set(width_axis, (inds[width_axis] + pad_left) / stride_width - windoww);

          PrimExpr out_idx_lower_h = tir::Select(
              pad_inds[height_axis] < kernel_height, make_const(pad_inds[height_axis].dtype(), 0),
              (pad_inds[height_axis] - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w = tir::Select(
              pad_inds[width_axis] < kernel_width, make_const(pad_inds[width_axis].dtype(), 0),
              (pad_inds[width_axis] - kernel_width) / stride_width + 1);

          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[width_axis] >= out_idx_lower_w),
                                         mp_inds(out_idx) == idx),
                                out_grad(out_idx), make_const(x->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_max");
  } else if (pool_type == kAvgPool) {
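    // Average pooling gradient: each input element accumulates out_grad from every
    // output window that covers it, divided by the number of pooled elements.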
    auto windowh =
        tvm::te::reduce_axis(Range(0, (kernel_height + stride_height - 1) / stride_height), "wh");
    auto windoww =
        tvm::te::reduce_axis(Range(0, (kernel_width + stride_width - 1) / stride_width), "ww");
    return tvm::te::compute(
        data_shape,
        [&](const Array<Var>& inds) {
          PrimExpr pad_h_idx = inds[height_axis] + pad_top;
          PrimExpr pad_w_idx = inds[width_axis] + pad_left;

          // output indices whose pooling windows cover current input element (can be out-of-bound)
          Array<PrimExpr> out_idx{inds.begin(), inds.end()};
          out_idx.Set(height_axis, (pad_h_idx / stride_height - windowh));
          out_idx.Set(width_axis, (pad_w_idx / stride_width - windoww));

          PrimExpr out_idx_lower_h =
              tir::Select(pad_h_idx < kernel_height, make_const(pad_h_idx.dtype(), 0),
                          (pad_h_idx - kernel_height) / stride_height + 1);
          PrimExpr out_idx_lower_w =
              tir::Select(pad_w_idx < kernel_width, make_const(pad_w_idx.dtype(), 0),
                          (pad_w_idx - kernel_width) / stride_width + 1);

          PrimExpr divide_factor;  // number of pooled elements
          if (count_include_pad) {
            divide_factor = kernel_height * kernel_width;
          } else {
            PrimExpr h_start = out_idx[height_axis] * stride_height - pad_top;
            PrimExpr w_start = out_idx[width_axis] * stride_width - pad_left;

            PrimExpr h_end = min(h_start + kernel_height, height);
            PrimExpr w_end = min(w_start + kernel_width, width);
            h_start = max(h_start, make_const(h_start.dtype(), 0));
            w_start = max(w_start, make_const(w_start.dtype(), 0));
            divide_factor =
                max((h_end - h_start) * (w_end - w_start), make_const(h_end.dtype(), 1));
          }
          return tvm::sum(
              tvm::if_then_else(tir::And(tir::And(out_idx[height_axis] >= out_idx_lower_h,
                                                  out_idx[height_axis] < out_height),
                                         tir::And(out_idx[width_axis] >= out_idx_lower_w,
                                                  out_idx[width_axis] < out_width)),
                                out_grad(out_idx) / divide_factor, make_const(out_grad->dtype, 0)),
              {windowh, windoww});
        },
        "T_pool_grad", "pool_grad_avg");
  } else {
    LOG(ERROR) << "Unrecognized pool_type: " << pool_type;
    return Tensor();
  }
}
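
A minimal call sketch. The shapes, dtype, and axis positions below are illustrative assumptions for an NCHW layout (height_axis = 2, width_axis = 3); they are not taken from this header. In practice this helper is normally reached through the layout-aware pool_grad wrapper in the same file, which derives the axis positions from a layout string.

// Hypothetical NCHW example: 1x3x32x32 input, 2x2 max pooling with stride 2 and
// no padding, so the pooled output (and therefore out_grad) is 1x3x16x16.
tvm::te::Tensor x = tvm::te::placeholder({1, 3, 32, 32}, tvm::DataType::Float(32), "x");
tvm::te::Tensor out_grad =
    tvm::te::placeholder({1, 3, 16, 16}, tvm::DataType::Float(32), "out_grad");
tvm::te::Tensor grad_x = pool_grad_impl(out_grad, x,
                                        /*kernel_size=*/{2, 2}, /*stride_size=*/{2, 2},
                                        /*padding_size=*/{0, 0, 0, 0}, kMaxPool,
                                        /*ceil_mode=*/false, /*height_axis=*/2,
                                        /*width_axis=*/3, /*count_include_pad=*/true);
// grad_x has the same shape as x and can be scheduled and built like any other
// te::compute result.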