void backward_train()

in src/inplace_abn.cpp [139:174]


/// @brief Validates the inputs of the training-mode backward pass, then hands
///        them to the `backward` implementation selected by CUDA_DISPATCH.
///
/// Checks run strictly in the order written below; the first one that fails
/// raises through IABN_CHECK / CHECK_SAME_TYPE and nothing is dispatched.
///
/// @param xhat         Must have at least 2 dimensions.
/// @param dy           Must match xhat in size and scalar type. Taken by
///                     non-const reference — presumably the backend writes
///                     into it (NOTE(review): confirm against the kernel).
/// @param var          Must pass is_compatible_stat against xhat.
/// @param count        Must be a 1-D, single-element int64 tensor.
/// @param sum_dy       Must pass is_compatible_stat against xhat.
/// @param sum_xhat_dy  Must pass is_compatible_stat against xhat.
/// @param weight       Optional; validated with is_compatible_weight only
///                     when a value is present.
/// @param eps          Forwarded unchanged to the dispatched backend.
void backward_train(
    const at::Tensor& xhat,
    at::Tensor& dy,
    const at::Tensor& var,
    const at::Tensor& count,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_xhat_dy,
    const c10::optional<at::Tensor>& weight,
    float eps) {
  // ---- xhat / dy ---------------------------------------------------------
  // Each condition is named right before its own check so evaluation order
  // (and short-circuiting) is exactly that of an inline check.
  const bool xhat_rank_ok = xhat.ndimension() >= 2;
  IABN_CHECK(xhat_rank_ok, "xhat should have at least 2 dimensions");

  const bool dy_matches_xhat = have_same_dims(xhat, dy);
  IABN_CHECK(dy_matches_xhat, "xhat and dy should have the same size");
  CHECK_SAME_TYPE(xhat, dy);

  // ---- var / count -------------------------------------------------------
  const bool var_ok = is_compatible_stat(xhat, var);
  IABN_CHECK(
      var_ok, "var is not compatible with xhat (wrong size or scalar type)");

  // size(0) is only queried once the tensor is known to be 1-D.
  const bool count_is_singleton =
      count.ndimension() == 1 && count.size(0) == 1;
  IABN_CHECK(
      count_is_singleton, "count should be a vector with a single element");

  const bool count_is_int64 = count.scalar_type() == at::ScalarType::Long;
  IABN_CHECK(count_is_int64, "count should have type int64");

  // ---- reduced gradient sums ---------------------------------------------
  const bool sum_dy_ok = is_compatible_stat(xhat, sum_dy);
  IABN_CHECK(
      sum_dy_ok,
      "sum_dy is not compatible with xhat (wrong size or scalar type)");

  const bool sum_xhat_dy_ok = is_compatible_stat(xhat, sum_xhat_dy);
  IABN_CHECK(
      sum_xhat_dy_ok,
      "sum_xhat_dy is not compatible with xhat (wrong size or scalar type)");

  // ---- optional affine weight --------------------------------------------
  if (weight.has_value()) {
    const at::Tensor& w = weight.value();
    const bool weight_ok = is_compatible_weight(xhat, w);
    IABN_CHECK(
        weight_ok,
        "weight is not compatible with xhat (wrong size or scalar type)");
  }

  // Invoked without a trailing semicolon, as in the original; presumably the
  // macro supplies its own statement terminator — confirm before adding one.
  CUDA_DISPATCH(
      xhat, backward, xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps)
}