in backends/common/lib/compat/eigen/kernels/batch_norm_grad.cc [32:181]
static void BatchNormGrad(
// clang-format off
// Inputs --------------------------------------------------------------- //
ArgumentView<DHTIndexableView<T, 4>> output_grad,
ArgumentView<DHTIndexableView<T, 4>> input,
Argument<DenseHostTensor> gamma, // scale
Argument<DenseHostTensor> moving_mean,
Argument<DenseHostTensor> moving_variance,
Argument<Chain> chain_in,
// Outputs -------------------------------------------------------------- //
ArgumentView<MutableDHTIndexableView<T, 4>> input_grad,
Argument<DenseHostTensor> gamma_grad, // scale_grad
Argument<DenseHostTensor> beta_grad, // offset_grad
Result<Chain> input_grad_chain,
Result<Chain> gamma_grad_chain,
Result<Chain> beta_grad_chain,
// Attributes ----------------------------------------------------------- //
Attribute<float> epsilon,
// Execution context ---------------------------------------------------- //
KernelErrorHandler handler, const ExecutionContext& exec_ctx,
AsyncKernelFrame* frame) {
// clang-format on
// Note: the following formulas are used to compute the gradients for
// back propagation:
//
// output_grad = scale * rsqrt(variance + epsilon) *
// (output_grad - mean(output_grad) - (x - mean(x)) *
// mean(output_grad * (x - mean(x))) / (variance + epsilon))
//
// gamma_grad = sum(output_grad *
// (x - mean(x)) * rsqrt(variance + epsilon))
//
// beta_grad = sum(output_grad)
TFRT_RETURN_IF_ERROR(
handler, CheckShapeMatch("output_grad shape", output_grad->FixedShape(),
"input_grad shape", input_grad->FixedShape()));
TFRT_RETURN_IF_ERROR(
handler, CheckShapeMatch("input shape", input->FixedShape(),
"input_grad shape", input_grad->FixedShape()));
// Data format: (batch_size, height, width, num_channels) [NHWC]
const FixedRankShape<4>& input_grad_shape = input_grad->FixedShape();
const auto depth = input_grad_shape[3]; // num channels
const auto channels_shape = TensorShape{depth};
TFRT_RETURN_IF_ERROR(
handler, CheckShapeMatch("channels dimension size", channels_shape,
"gamma shape", gamma->shape()));
TFRT_RETURN_IF_ERROR(
handler, CheckShapeMatch("channels dimension size", channels_shape,
"mean shape", moving_mean->shape()));
TFRT_RETURN_IF_ERROR(
handler, CheckShapeMatch("channels dimension size", channels_shape,
"variance shape", moving_variance->shape()));
TFRT_RETURN_IF_ERROR(
handler, CheckShapeMatch("channels dimension size", channels_shape,
"gamma_grad shape", gamma_grad->shape()));
TFRT_RETURN_IF_ERROR(
handler, CheckShapeMatch("channels dimension size", channels_shape,
"beta_grad shape", beta_grad->shape()));
// Flatten all outer dimensions of input{grad}/output_grad.
const Index rest_size = output_grad->NumElements() / depth;
Eigen::DSizes<Eigen::Index, 2> rest_by_depth(rest_size, depth);
// Resize all vectors into 2d Tensors.
Eigen::DSizes<Eigen::Index, 2> one_by_depth(1, depth);
// Compute reductions (sum and mean) over outer dimension (after flattening).
Eigen::DSizes<int, 1> reduce_dims(0);
// Broadcast 1d vectors to the input/output shape.
Eigen::DSizes<int, 2> bcast_spec(rest_size, 1);
auto& ctx = exec_ctx.host()->GetOrCreateSharedContext<EigenHostContext>();
// Reshape input/output arguments into [rest_size, depth] tensors.
const FixedRankShape<2> rest_by_depth_s = AsShape(rest_by_depth);
auto output_grad_t = AsEigenConstTensor(output_grad.get(), rest_by_depth_s);
auto input_t = AsEigenConstTensor(input.get(), rest_by_depth_s);
auto input_grad_t = AsEigenTensor(input_grad.get(), rest_by_depth_s);
// Reshape input vectors into [1, depth] tensors.
const FixedRankShape<2> one_by_depth_s = AsShape(one_by_depth);
auto gamma_t = AsEigenConstTensor<T>(&*gamma, one_by_depth_s);
auto mean_t = AsEigenConstTensor<T>(&*moving_mean, one_by_depth_s);
auto variance_t = AsEigenConstTensor<T>(&*moving_variance, one_by_depth_s);
// Output gradients of [depth] shape.
auto gamma_grad_t = AsEigenTensor<T>(&*gamma_grad);
auto beta_grad_t = AsEigenTensor<T>(&*beta_grad);
T rest_size_inv = static_cast<T>(1.0f / static_cast<T>(rest_size));
auto coef0 = (variance_t + epsilon.get()).rsqrt(); // [1, depth]
auto coef1 = (gamma_t * coef0).eval().broadcast(bcast_spec); // [rest, depth]
auto input_centered = (input_t - mean_t.broadcast(bcast_spec));
auto input_scaled = input_centered * (coef0.eval().broadcast(bcast_spec));
// Allocate output chains for all results, because they must be not null
// before we copy the kernel frame below.
auto input_grad_ready = input_grad_chain.Allocate();
auto gamma_grad_ready = gamma_grad_chain.Allocate();
auto beta_grad_ready = beta_grad_chain.Allocate();
//=== gamma/scale gradient ----------------------------------------------===//
auto gamma_grad_expr = (output_grad_t * input_scaled).sum(reduce_dims);
AsyncAssign(ctx, std::move(gamma_grad_t), std::move(gamma_grad_expr),
[chain = std::move(gamma_grad_ready), frame = *frame]() {
chain.emplace();
});
//=== beta/offset gradient ----------------------------------------------===//
auto output_grad_sum = output_grad_t.sum(reduce_dims);
AsyncAssign(ctx, std::move(beta_grad_t), output_grad_sum,
[chain = std::move(beta_grad_ready), frame = *frame]() {
chain.emplace();
});
//=== input gradient ----------------------------------------------------===//
auto output_grad_sum_one_by_depth =
output_grad_sum.eval().reshape(one_by_depth);
auto output_grad_mean_one_by_depth =
output_grad_sum_one_by_depth * rest_size_inv;
auto output_grad_mean = output_grad_mean_one_by_depth.broadcast(bcast_spec);
auto output_grad_centered = output_grad_t - output_grad_mean;
auto coef2 =
(coef0.square() *
(output_grad_t * input_centered).mean(reduce_dims).reshape(one_by_depth))
.eval()
.broadcast(bcast_spec);
auto input_grad_expr =
coef1 * (output_grad_centered - input_centered * coef2);
AsyncAssign(ctx, std::move(input_grad_t), std::move(input_grad_expr),
[chain = std::move(input_grad_ready), frame = *frame]() {
chain.emplace();
});
}