std::tuple backward_reduce_impl()

in src/inplace_abn_cpu.cpp [25:87]


// Reduction stage of the backward pass, CPU implementation.
//
// For every element of `y_act_` this routine
//   1. calls ActivationFn<scalar_t, activation>::backward, which is expected to
//      fill the matching elements of the `xhat` and `dy` outputs (both are
//      allocated uninitialized and read right after the call),
//   2. undoes the affine transform: xhat = (xhat - bias[c]) / (|weight[c]| + eps),
//   3. adds `dy` and `xhat * dy` to the per-channel sums.
//
// NOTE(review): `scalar_t` and `activation` are not declared in the visible
// lines; presumably they are template parameters introduced just above this
// definition -- confirm.
// NOTE(review): normalize_shape() is assumed to return a 3-D view sharing
// storage with its argument (the outputs are written through those views and
// indexed as [n][c][s]) -- confirm against its definition.
//
// Parameters:
//   y_act_            activated output; its extents drive every loop below.
//   dy_act_           gradient w.r.t. `y_act_`; indexed with the extents of
//                     `y_act_`, no shape check is performed here.
//   weight_           optional per-channel scale; when absent the inverse scale is 1.
//   bias_             optional per-channel shift; when absent the shift is 0.
//   eps               added to |weight[c]| before taking the reciprocal.
//   activation_param  forwarded unchanged to ActivationFn::backward.
//
// Returns the tuple (xhat, dy, sum_dy, sum_xhat_dy). The first two are created
// with at::empty_like(y_act_); the last two are 1-D tensors of length
// y_act_.size(1) created with y_act_'s options.
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_impl(
    const at::Tensor& y_act_,
    const at::Tensor& dy_act_,
    const c10::optional<at::Tensor>& weight_,
    const c10::optional<at::Tensor>& bias_,
    float eps,
    float activation_param) {
  // Element-wise outputs: every entry is overwritten in the main loop.
  auto xhat_out = at::empty_like(y_act_);
  auto dy_out = at::empty_like(y_act_);

  // Per-channel accumulators: must start at zero because they are summed into.
  const auto reduce_opts = y_act_.options();
  auto sum_dy_out = at::zeros({y_act_.size(1)}, reduce_opts);
  auto sum_xhat_dy_out = at::zeros({y_act_.size(1)}, reduce_opts);

  // Bring inputs and element-wise outputs to the common 3-D layout.
  auto y_act_3d = normalize_shape(y_act_);
  auto dy_act_3d = normalize_shape(dy_act_);
  auto xhat_3d = normalize_shape(xhat_out);
  auto dy_3d = normalize_shape(dy_out);

  // Loop extents are taken from the normalized activation tensor.
  const int64_t batch = y_act_3d.size(0);
  const int64_t channels = y_act_3d.size(1);
  const int64_t spatial = y_act_3d.size(2);

  // Typed accessors for the element-wise tensors.
  auto y_act_acc = y_act_3d.accessor<scalar_t, 3>();
  auto dy_act_acc = dy_act_3d.accessor<scalar_t, 3>();
  auto xhat_acc = xhat_3d.accessor<scalar_t, 3>();
  auto dy_acc = dy_3d.accessor<scalar_t, 3>();

  // Typed accessors for the per-channel accumulators.
  auto sum_dy_acc = sum_dy_out.accessor<scalar_t, 1>();
  auto sum_xhat_dy_acc = sum_xhat_dy_out.accessor<scalar_t, 1>();

  // Affine parameters; these accessors are only indexed when the matching
  // optional actually holds a tensor.
  auto gamma = accessor_or_dummy<scalar_t, 1>(weight_);
  auto beta = accessor_or_dummy<scalar_t, 1>(bias_);

  for (int64_t ch = 0; ch < channels; ++ch) {
    // Per-channel constants of the inverse affine transform.
    const auto inv_scale = !weight_.has_value()
        ? scalar_t(1)
        : scalar_t(1) / (std::abs(gamma[ch]) + eps);
    const auto shift = !bias_.has_value() ? scalar_t(0) : beta[ch];

    // Accumulate straight into the output tensors' storage for this channel.
    auto& total_dy = sum_dy_acc[ch];
    auto& total_xhat_dy = sum_xhat_dy_acc[ch];

    for (int64_t b = 0; b < batch; ++b) {
      auto y_act_row = y_act_acc[b][ch];
      auto dy_act_row = dy_act_acc[b][ch];
      auto xhat_row = xhat_acc[b][ch];
      auto dy_row = dy_acc[b][ch];

      for (int64_t i = 0; i < spatial; ++i) {
        auto& xhat_elem = xhat_row[i];
        auto& dy_elem = dy_row[i];

        // Step 1: invert the activation (fills xhat_elem and dy_elem).
        ActivationFn<scalar_t, activation>::backward(
            y_act_row[i], dy_act_row[i], activation_param, xhat_elem, dy_elem);

        // Step 2: invert the affine transformation.
        xhat_elem = (xhat_elem - shift) * inv_scale;

        // Step 3: update the per-channel reductions.
        total_dy += dy_elem;
        total_xhat_dy += xhat_elem * dy_elem;
      }
    }
  }

  return std::make_tuple(xhat_out, dy_out, sum_dy_out, sum_xhat_dy_out);
}