at::Tensor backward_test()

in src/inplace_abn.cpp [176:198]


/// Backward pass of the normalization in inference ("test") mode.
///
/// With the statistics fixed, the input gradient is the incoming gradient
/// rescaled per element of the normalized layout:
///   dx = dy * (|weight| + eps) / sqrt(var + eps)   when weight is given
///   dx = dy / sqrt(var + eps)                      otherwise
///
/// @param dy_    Incoming gradient; must have at least 2 dimensions.
/// @param var    Variance statistic; must pass is_compatible_stat(dy_, var).
/// @param weight Optional weight; when present it must pass
///               is_compatible_weight(dy_, weight).
/// @param eps    Constant added to var, and to |weight| when weight is given.
/// @return Gradient with respect to the input, carrying dy_'s tensor options
///         and sizes.
///
/// All arithmetic uses var's options (dtype/device). dy_ and weight are
/// converted to those options first, and the result is converted back to
/// dy_'s options.
/// NOTE(review): assumes normalize_shape folds a tensor into a layout where
/// the per-channel factor broadcasts against dy. Confirm this against
/// normalize_shape's definition.
at::Tensor backward_test(
    const at::Tensor& dy_,
    const at::Tensor& var,
    const c10::optional<at::Tensor>& weight,
    float eps) {
  // Validate shapes and scalar types before doing any work.
  IABN_CHECK(dy_.ndimension() >= 2, "dy should have at least 2 dimensions");
  IABN_CHECK(
      is_compatible_stat(dy_, var),
      "var is not compatible with dy (wrong size or scalar type)");
  if (weight) {
    IABN_CHECK(
        is_compatible_weight(dy_, *weight),
        "weight is not compatible with dy (wrong size or scalar type)");
  }

  // TODO: optimize implementation for GPU
  // Bring the gradient into the normalized layout and into var's options.
  const auto dy = normalize_shape(dy_).to(var.options());

  // Denominator shared by both branches below.
  const auto std_dev = (var + eps).sqrt();

  at::Tensor scale;
  if (weight) {
    // Cast weight to var's options so the arithmetic stays in var's
    // dtype/device.
    scale = (weight->to(var.options()).abs() + eps) / std_dev;
  } else {
    scale = 1 / std_dev;
  }

  const auto dx = normalize_shape(scale) * dy;

  // Restore dy_'s dtype/device and its original shape.
  return dx.to(dy_.options()).view(dy_.sizes());
}