in functorch/csrc/BatchRulesNorm.cpp [564:649]
// Plumbing for the backward of native_layer_norm at the current dynamic layer.
//
// Returns (grad_input, grad_weight, grad_bias). Each entry is only computed
// when the matching output_mask slot is set (output_mask[0] -> grad_input,
// output_mask[1] -> grad_weight, output_mask[2] -> grad_bias); grad_weight and
// grad_bias additionally require the corresponding optional tensor to be
// defined. Entries that are not computed are returned as default-constructed
// Tensors.
//
// grad_weight and grad_bias are built from ordinary tensor ops on the inputs
// as passed in, while grad_input is delegated to
// native_layer_norm_backward_no_weight_bias_batch_rule on the values unwrapped
// at the current level.
std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_plumbing(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    at::IntArrayRef normalized_shape,
    const at::Tensor & mean,
    const at::Tensor & rstd,
    const c10::optional<at::Tensor> & weight_opt,
    const c10::optional<at::Tensor> & bias_opt,
    std::array<bool,3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_holder = at::borrow_from_optional_tensor(weight_opt);
  c10::MaybeOwned<Tensor> bias_holder = at::borrow_from_optional_tensor(bias_opt);
  const Tensor& weight = *weight_holder;
  const Tensor& bias = *bias_holder;

  // A dynamic layer must be active; everything below is unwrapped at its level.
  auto active_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(active_layer.has_value());
  int64_t level = active_layer->layerId();

  Tensor grad_out_unwrapped;
  optional<int64_t> grad_out_batch_dim;
  std::tie(grad_out_unwrapped, grad_out_batch_dim) = unwrapTensorAtLevel(grad_out, level);

  Tensor input_unwrapped;
  optional<int64_t> input_batch_dim;
  std::tie(input_unwrapped, input_batch_dim) = unwrapTensorAtLevel(input, level);

  Tensor mean_unwrapped;
  optional<int64_t> mean_batch_dim;
  std::tie(mean_unwrapped, mean_batch_dim) = unwrapTensorAtLevel(mean, level);

  Tensor rstd_unwrapped;
  optional<int64_t> rstd_batch_dim;
  std::tie(rstd_unwrapped, rstd_batch_dim) = unwrapTensorAtLevel(rstd, level);

  // For weight and bias only the presence of the unwrapped value is consulted
  // further down.
  optional<Tensor> weight_unwrapped;
  optional<int64_t> weight_batch_dim;
  if (weight.defined()) {
    std::tie(weight_unwrapped, weight_batch_dim) = unwrapTensorAtLevel(weight, level);
  }

  optional<Tensor> bias_unwrapped;
  optional<int64_t> bias_batch_dim;
  if (bias.defined()) {
    std::tie(bias_unwrapped, bias_batch_dim) = unwrapTensorAtLevel(bias, level);
  }

  // Sums away every dimension in front of the trailing normalized_shape dims;
  // when there are none, the tensor is handed back unchanged.
  const auto sum_leading_dims = [&normalized_shape](const Tensor& t) -> Tensor {
    const auto leading_dim_count = t.dim() - normalized_shape.size();
    if (leading_dim_count == 0) {
      return t;
    }
    return t.sum(range(0, leading_dim_count));
  };

  Tensor grad_bias;
  Tensor grad_weight;
  Tensor grad_input;

  const bool wants_grad_bias = output_mask[2] && bias_unwrapped.has_value();
  if (wants_grad_bias) {
    grad_bias = sum_leading_dims(grad_out);
  }

  const bool wants_grad_weight = output_mask[1] && weight_unwrapped.has_value();
  if (wants_grad_weight) {
    // NB: the forward output isn't saved, so the normalized input is
    // recomputed here from input, mean and rstd.
    const auto recomputed_normalized = (input - mean) * rstd;
    const auto per_element_grad_weight = recomputed_normalized * grad_out;
    grad_weight = sum_leading_dims(per_element_grad_weight);
  }

  if (output_mask[0]) {
    // Fold the weight (when present) into grad_out so the batch rule that
    // ignores weight/bias can be used for grad_input.
    Tensor scaled_grad_out = grad_out;
    if (weight.defined()) {
      scaled_grad_out = grad_out * weight;
    }

    Tensor scaled_grad_out_unwrapped;
    optional<int64_t> scaled_grad_out_batch_dim;
    std::tie(scaled_grad_out_unwrapped, scaled_grad_out_batch_dim) =
        unwrapTensorAtLevel(scaled_grad_out, level);

    // The batched key is excluded only for the duration of the batch-rule call
    // and the re-wrapping below.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto batch_rule_out = native_layer_norm_backward_no_weight_bias_batch_rule(
        scaled_grad_out_unwrapped, scaled_grad_out_batch_dim,
        input_unwrapped, input_batch_dim,
        normalized_shape,
        mean_unwrapped, mean_batch_dim,
        rstd_unwrapped, rstd_batch_dim);
    grad_input = makeBatched(std::get<0>(batch_rule_out), std::get<1>(batch_rule_out), level);
  }

  return std::make_tuple(grad_input, grad_weight, grad_bias);
}