in src/inplace_abn_cpu.cpp [172:195]
// CPU backward step for in-place activated batch norm: converts the incoming
// gradient `dy_` into the gradient w.r.t. the pre-normalization input, writing
// the result back through `dy_` via in-place tensor ops.
//
// Element-wise (per-channel statistics broadcast over the activations):
//   dy <- (dy - sum_dy / count - xhat * (sum_xhat_dy / count)) * scale
// where
//   scale = (|weight| + eps) / sqrt(var + eps)   if `weight` holds a tensor,
//   scale = 1 / sqrt(var + eps)                 otherwise.
//
// Parameters:
//   xhat_       normalized activations; half-precision input is rejected by
//               CHECK_NOT_HALF.
//   dy_         incoming gradient, overwritten in place. NOTE(review): this
//               relies on normalize_shape returning a view of dy_ -- confirm.
//   var         variance used for the sqrt(var + eps) denominator (presumably
//               per-channel batch statistics from the forward pass -- confirm).
//   count       element count used to turn the sums into means; converted to
//               the tensor options of each sum before dividing.
//   sum_dy      presumably the per-channel sum of dy (divided by count here).
//   sum_xhat_dy presumably the per-channel sum of xhat * dy (divided by count).
//   weight      optional affine scale; only |weight| + eps enters the result.
//   eps         added to var under the square root and, when a weight is
//               given, to |weight|.
void backward_cpu(
    const at::Tensor& xhat_,
    at::Tensor& dy_,
    const at::Tensor& var,
    const at::Tensor& count,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_xhat_dy,
    const c10::optional<at::Tensor>& weight,
    float eps) {
  CHECK_NOT_HALF(xhat_);

  auto xhat = normalize_shape(xhat_);
  auto grad = normalize_shape(dy_);

  // Means over the `count` reduced elements; count is cast to match each sum.
  const auto dy_count = count.to(sum_dy.options());
  auto mean_dy = normalize_shape(sum_dy / dy_count);
  const auto xhat_dy_count = count.to(sum_xhat_dy.options());
  auto mean_xhat_dy = normalize_shape(sum_xhat_dy / xhat_dy_count);

  // Output scale. With eps > 0, |weight| + eps keeps it strictly positive
  // even where a weight entry is zero.
  const auto std_dev = (var + eps).sqrt();
  at::Tensor scale;
  if (weight.has_value()) {
    scale = (weight.value().abs() + eps) / std_dev;
  } else {
    scale = 1 / std_dev;
  }

  // grad = (grad - mean_dy - xhat * mean_xhat_dy) * scale, applied in place.
  grad.sub_(mean_dy);
  grad.sub_(xhat * mean_xhat_dy);
  grad.mul_(normalize_shape(scale));
}