in src/inplace_abn_cpu.cpp [25:87]
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_impl(
    const at::Tensor& y_act_,
    const at::Tensor& dy_act_,
    const c10::optional<at::Tensor>& weight_,
    const c10::optional<at::Tensor>& bias_,
    float eps,
    float activation_param) {
  // Initialize output tensors
  auto xhat_ = at::empty_like(y_act_);
  auto dy_ = at::empty_like(y_act_);
  auto sum_dy_ = at::zeros({y_act_.size(1)}, y_act_.options());
  auto sum_xhat_dy_ = at::zeros({y_act_.size(1)}, y_act_.options());
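  // xhat_ and dy_ hold per-element results, while sum_dy_ and sum_xhat_dy_
  // accumulate one value per channel (dimension 1 of y_act_).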

  // Normalize shapes
  auto y_act_norm_ = normalize_shape(y_act_);
  auto dy_act_norm_ = normalize_shape(dy_act_);
  auto xhat_norm_ = normalize_shape(xhat_);
  auto dy_norm_ = normalize_shape(dy_);

  // Get dimensions
  int64_t num = y_act_norm_.size(0), chn = y_act_norm_.size(1),
          sp = y_act_norm_.size(2);
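  // normalize_shape is expected to yield a 3-d (batch, channel, spatial) view,
  // so num, chn and sp index those three axes in the loops below.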

  // Make accessors
  auto y_act = y_act_norm_.accessor<scalar_t, 3>();
  auto dy_act = dy_act_norm_.accessor<scalar_t, 3>();
  auto xhat = xhat_norm_.accessor<scalar_t, 3>();
  auto dy = dy_norm_.accessor<scalar_t, 3>();
  auto weight = accessor_or_dummy<scalar_t, 1>(weight_);
  auto bias = accessor_or_dummy<scalar_t, 1>(bias_);
  auto sum_dy = sum_dy_.accessor<scalar_t, 1>();
  auto sum_xhat_dy = sum_xhat_dy_.accessor<scalar_t, 1>();
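  // accessor_or_dummy is assumed to return a placeholder accessor when the
  // optional tensor is absent; weight[c] and bias[c] are only read behind the
  // corresponding has_value() checks below.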

  // Main loop
  for (int64_t c = 0; c < chn; ++c) {
    auto inv_gamma_c = weight_.has_value()
        ? scalar_t(1) / (std::abs(weight[c]) + eps)
        : scalar_t(1);
    auto beta_c = bias_.has_value() ? bias[c] : scalar_t(0);
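    // inv_gamma_c and beta_c undo the forward affine transform
    // y = xhat * (|weight| + eps) + bias (identity when weight / bias are absent).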
    for (int64_t n = 0; n < num; ++n) {
      auto y_act_nc = y_act[n][c];
      auto dy_act_nc = dy_act[n][c];
      auto xhat_nc = xhat[n][c];
      auto dy_nc = dy[n][c];

      for (int64_t s = 0; s < sp; ++s) {
        // Invert activation
        ActivationFn<scalar_t, activation>::backward(
            y_act_nc[s], dy_act_nc[s], activation_param, xhat_nc[s], dy_nc[s]);

        // Invert affine transformation
        xhat_nc[s] = (xhat_nc[s] - beta_c) * inv_gamma_c;

        // Accumulate
        sum_dy[c] += dy_nc[s];
        sum_xhat_dy[c] += xhat_nc[s] * dy_nc[s];
      }
    }
  }

  return std::make_tuple(xhat_, dy_, sum_dy_, sum_xhat_dy_);
}
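
For context, the two per-channel reductions returned above are the sums that appear in the textbook batch-normalization gradient, dx = gamma / sqrt(var + eps) * (dy - mean(dy) - xhat * mean(xhat * dy)). The sketch below shows one plausible way a caller could consume them for a single (sample, channel) slice; it is illustrative only, not taken from this file, and the names dx_nc, gamma_c, var_c and count are assumptions.

// Illustrative sketch (not part of inplace_abn_cpu.cpp): turning the
// reductions from backward_reduce_impl into the input gradient for one
// (sample, channel) slice, following the standard batch-norm backward pass.
// gamma_c, var_c, eps and count (= num * sp) are assumed to be supplied
// by the caller.
#include <ATen/ATen.h>
#include <cmath>

template <typename scalar_t>
void elementwise_backward_sketch(
    at::TensorAccessor<scalar_t, 1> xhat_nc,  // normalized inputs, recovered by backward_reduce_impl
    at::TensorAccessor<scalar_t, 1> dy_nc,    // gradient w.r.t. the pre-activation output
    at::TensorAccessor<scalar_t, 1> dx_nc,    // output: gradient w.r.t. the input
    scalar_t sum_dy_c,                        // sum_dy_[c]
    scalar_t sum_xhat_dy_c,                   // sum_xhat_dy_[c]
    scalar_t gamma_c,
    scalar_t var_c,
    scalar_t eps,
    int64_t count,
    int64_t sp) {
  const scalar_t mean_dy = sum_dy_c / count;
  const scalar_t mean_xhat_dy = sum_xhat_dy_c / count;
  const scalar_t mult = gamma_c / std::sqrt(var_c + eps);
  for (int64_t s = 0; s < sp; ++s) {
    // dx = gamma / sqrt(var + eps) * (dy - mean(dy) - xhat * mean(xhat * dy))
    dx_nc[s] = mult * (dy_nc[s] - mean_dy - xhat_nc[s] * mean_xhat_dy);
  }
}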