void backward_cpu()

in src/inplace_abn_cpu.cpp [172:195]


// CPU backward step: rewrites `dy_` in place with the gradient w.r.t. the
// normalized layer input,
//
//   dy <- (dy - mean_dy - xhat * mean_xhat_dy) * mult
//   mean_dy      = sum_dy      / count
//   mean_xhat_dy = sum_xhat_dy / count
//   mult         = (|weight| + eps) / sqrt(var + eps)   if weight is given
//                  1 / sqrt(var + eps)                   otherwise
//
// This is the usual batch-norm input-gradient form, with the pre-reduced sums
// supplied by the caller.
//
// Parameters:
//   xhat_       normalized activations; CHECK_NOT_HALF rejects half precision.
//   dy_         incoming gradient; updated in place through the reshaped
//               `dy` (assumes normalize_shape returns a view of its argument,
//               so the writes land in dy_ -- TODO confirm).
//   var         variance; presumably per-channel, shape (C,) -- verify.
//   count       number of elements behind each sum; cast to the sums'
//               dtype/device before dividing.
//   sum_dy      reduced sum of dy.
//   sum_xhat_dy reduced sum of xhat * dy.
//   weight      optional scale; |weight| + eps is used as the multiplier.
//   eps         added to var under the square root and to |weight|.
void backward_cpu(
    const at::Tensor& xhat_,
    at::Tensor& dy_,
    const at::Tensor& var,
    const at::Tensor& count,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_xhat_dy,
    const c10::optional<at::Tensor>& weight,
    float eps) {
  CHECK_NOT_HALF(xhat_);

  auto xhat = normalize_shape(xhat_);
  auto dy = normalize_shape(dy_);
  auto mean_dy = normalize_shape(sum_dy / count.to(sum_dy.options()));
  auto mean_xhat_dy =
      normalize_shape(sum_xhat_dy / count.to(sum_xhat_dy.options()));

  auto mult = weight.has_value()
      ? (weight.value().abs() + eps) / (var + eps).sqrt()
      : 1 / (var + eps).sqrt();

  // dy = (dy - mean_dy - xhat * mean_xhat_dy) * mult
  // addcmul_(xhat, mean_xhat_dy, -1) computes dy += -1 * xhat * mean_xhat_dy
  // directly into dy. This avoids materializing `xhat * mean_xhat_dy` as a
  // temporary as large as the whole activation tensor. Multiplying by -1 is
  // an exact negation, so the values match the explicit subtraction.
  dy.sub_(mean_dy)
      .addcmul_(xhat, mean_xhat_dy, -1)
      .mul_(normalize_shape(mult));
}