in src/inplace_abn.cpp [139:174]
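// backward_train: training-mode backward pass of in-place ABN.
//
// Validates shapes/dtypes and dispatches to the device-specific backward
// kernel via CUDA_DISPATCH. Argument roles, as implied by the checks and
// names below:
//  - xhat:        normalized input, at least 2-D (channel dim assumed at 1)
//  - dy:          gradient w.r.t. the output, same size and dtype as xhat;
//                 taken non-const, presumably so the kernel can overwrite it
//                 with the input gradient in place
//  - var:         per-channel variance statistic compatible with xhat
//  - count:       single-element int64 vector holding the reduction count
//  - sum_dy:      per-channel sum of dy
//  - sum_xhat_dy: per-channel sum of xhat * dy
//  - weight:      optional per-channel scale (gamma), compatible with xhat
//  - eps:         small constant added to var for numerical stability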
void backward_train(
    const at::Tensor& xhat,
    at::Tensor& dy,
    const at::Tensor& var,
    const at::Tensor& count,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_xhat_dy,
    const c10::optional<at::Tensor>& weight,
    float eps) {
  // Check dimensions and types
  IABN_CHECK(xhat.ndimension() >= 2, "xhat should have at least 2 dimensions");
  IABN_CHECK(have_same_dims(xhat, dy), "xhat and dy should have the same size");
  CHECK_SAME_TYPE(xhat, dy);
  IABN_CHECK(
      is_compatible_stat(xhat, var),
      "var is not compatible with xhat (wrong size or scalar type)");
  IABN_CHECK(
      count.ndimension() == 1 && count.size(0) == 1,
      "count should be a vector with a single element");
  IABN_CHECK(
      count.scalar_type() == at::ScalarType::Long,
      "count should have type int64");
  IABN_CHECK(
      is_compatible_stat(xhat, sum_dy),
      "sum_dy is not compatible with xhat (wrong size or scalar type)");
  IABN_CHECK(
      is_compatible_stat(xhat, sum_xhat_dy),
      "sum_xhat_dy is not compatible with xhat (wrong size or scalar type)");
  if (weight.has_value())
    IABN_CHECK(
        is_compatible_weight(xhat, weight.value()),
        "weight is not compatible with xhat (wrong size or scalar type)");

  CUDA_DISPATCH(
      xhat, backward, xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, eps)
}
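
// ---------------------------------------------------------------------------
// Hedged reference sketch (not part of the original file). It spells out the
// gradient that the dispatched kernel is expected to produce for standard
// training-mode batch normalization,
//
//   dx = gamma / sqrt(var + eps)
//        * (dy - sum_dy / count - xhat * sum_xhat_dy / count),
//
// with per-channel statistics broadcast over an (N, C, *) layout and dy
// overwritten in place. The function name, the broadcasting helper, and the
// channel-at-dim-1 assumption are illustrative only; the real computation
// lives in the backward kernels selected by CUDA_DISPATCH, which may differ
// in detail. Assumes the ATen and <vector> headers already included by this
// translation unit.
namespace {

// Reshape a per-channel vector (C,) to (1, C, 1, ..., 1) so it broadcasts
// against a tensor laid out as (N, C, *).
at::Tensor broadcast_stat(const at::Tensor& stat, const at::Tensor& like) {
  std::vector<int64_t> shape(like.ndimension(), 1);
  shape[1] = stat.size(0);
  return stat.view(shape);
}

// Plain-ATen version of the per-element backward step, for illustration.
void backward_train_reference(
    const at::Tensor& xhat,
    at::Tensor& dy,
    const at::Tensor& var,
    const at::Tensor& count,
    const at::Tensor& sum_dy,
    const at::Tensor& sum_xhat_dy,
    const c10::optional<at::Tensor>& weight,
    float eps) {
  const double n = static_cast<double>(count.item<int64_t>());
  at::Tensor inv_std = (var + eps).rsqrt();
  at::Tensor scale = weight.has_value() ? weight.value() * inv_std : inv_std;

  // dy <- gamma / sqrt(var + eps) * (dy - mean(dy) - xhat * mean(xhat * dy))
  dy.sub_(broadcast_stat(sum_dy, xhat) / n)
      .sub_(xhat * (broadcast_stat(sum_xhat_dy, xhat) / n))
      .mul_(broadcast_stat(scale, xhat));
}

} // namespace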