in functorch/csrc/BatchRulesNorm.cpp [345:429]
std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
    const Tensor & grad_out, const Tensor & input, const Tensor & mean,
    const Tensor & rstd, const c10::optional<Tensor> & weight_opt,
    int64_t N, int64_t C, int64_t HxW, int64_t group, std::array<bool,3> output_mask
) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
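  // Unwrap each argument into its underlying tensor plus an optional batch
  // dim index at the current vmap level (nullopt if the arg isn't batched).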
  Tensor grad_out_value;
  optional<int64_t> grad_out_bdim;
  std::tie(grad_out_value, grad_out_bdim) = unwrapTensorAtLevel(grad_out, cur_level);
  Tensor input_value;
  optional<int64_t> input_bdim;
  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
  Tensor weight_value;
  optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  Tensor mean_value;
  optional<int64_t> mean_bdim;
  std::tie(mean_value, mean_bdim) = unwrapTensorAtLevel(mean, cur_level);
  Tensor rstd_value;
  optional<int64_t> rstd_bdim;
  std::tie(rstd_value, rstd_bdim) = unwrapTensorAtLevel(rstd, cur_level);
  // results
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;
  TORCH_INTERNAL_ASSERT(grad_out.dim() > 1); // group_norm can't operate on 1D tensors, so grad_out is at least 2D
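  // grad_bias: reduce grad_out over every dim except the channel dim. This runs
  // on the still-wrapped tensors, so vmap handles any batch dims automatically.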
  if (output_mask[2]) {
    grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim()));
  }
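  // grad_weight: recompute the normalized input, (x - mean) * rstd per group,
  // then reduce grad_out * normalized_input over every dim except the channel dim.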
  if (output_mask[1] && weight.defined()) {
    const auto reshaped_input = reshape_dim_outof(1, group, input);
    const auto normalized_input =
        (reshaped_input - padRight(mean, nullopt, reshaped_input.dim())) * padRight(rstd, nullopt, reshaped_input.dim());
    const auto expanded_grad_weight = reshape_dim_into(1, 1, normalized_input) * grad_out;
    grad_weight = expanded_grad_weight.transpose(0, 1).sum(range(1, expanded_grad_weight.dim()));
  }
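  // grad_input: dispatch to the non-batched native_group_norm_backward kernel on
  // the unwrapped tensors, folding the vmap batch dim into the N dimension.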
  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
        grad_out * padRight(weight, nullopt, grad_out.dim() - 1) : grad_out;
    Tensor grad_normalized_input_value;
    optional<int64_t> grad_normalized_input_bdim;
    std::tie(grad_normalized_input_value, grad_normalized_input_bdim) =
        unwrapTensorAtLevel(grad_normalized_input, cur_level);
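    // Move any batch dims to the front, then give every tensor a batch dim of
    // size bdim_size so the shapes line up before it is folded into N.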
    auto grad_out_ = moveBatchDimToFront(grad_normalized_input_value, grad_normalized_input_bdim);
    auto input_ = moveBatchDimToFront(input_value, input_bdim);
    auto mean_ = moveBatchDimToFront(mean_value, mean_bdim);
    auto rstd_ = moveBatchDimToFront(rstd_value, rstd_bdim);
    const auto bdim_size = get_bdim_size3(grad_out_, grad_normalized_input_bdim, input_, input_bdim, weight_value, weight_bdim);
    grad_out_ = ensure_has_bdim(grad_out_, grad_normalized_input_bdim.has_value(), bdim_size);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
    mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
    rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);
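    // Fold the batch dim into the N dim so the regular kernel sees B0 * N samples.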
    grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
    input_ = reshape_dim_into(0, 0, input_);       // [B0 * N, C, *]
    mean_ = reshape_dim_into(0, 0, mean_);         // [B0 * N, G]
    rstd_ = reshape_dim_into(0, 0, rstd_);         // [B0 * N, G]
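    // Exclude the batching dispatch key so the call below reaches the regular
    // (non-vmap) native_group_norm_backward kernel, computing only grad_input.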
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto result = native_group_norm_backward(
        grad_out_,
        input_,
        mean_,
        rstd_,
        nullopt, N * bdim_size, C, HxW, group, {true, false, false});
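    // Split the folded batch dim back out of grad_input and re-wrap it as a
    // batched tensor at the current vmap level.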
    auto result0 = std::get<0>(result);
    result0 = reshape_dim_outof(0, bdim_size, result0);
    grad_input = makeBatched(result0, 0, cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}