std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing()

in functorch/csrc/BatchRulesNorm.cpp [345:429]


std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
  const Tensor & grad_out, const Tensor & input, const Tensor & mean, 
  const Tensor & rstd, const c10::optional<Tensor> & weight_opt,
  int64_t N, int64_t C, int64_t HxW, int64_t group, std::array<bool,3> output_mask
) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;

  // plumbing: look up the current vmap level, then unwrap each argument into
  // its underlying value and optional batch dim at that level
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();

  Tensor grad_out_value;
  optional<int64_t> grad_out_bdim;
  std::tie(grad_out_value, grad_out_bdim) = unwrapTensorAtLevel(grad_out, cur_level);
  Tensor input_value;
  optional<int64_t> input_bdim;
  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
  Tensor weight_value;
  optional<int64_t> weight_bdim;
  if (weight.defined()){
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  Tensor mean_value;
  optional<int64_t> mean_bdim;
  std::tie(mean_value, mean_bdim) = unwrapTensorAtLevel(mean, cur_level);
  Tensor rstd_value;
  optional<int64_t> rstd_bdim;
  std::tie(rstd_value, rstd_bdim) = unwrapTensorAtLevel(rstd, cur_level);

  // results
  Tensor grad_input;
  Tensor grad_weight;
  Tensor grad_bias;

  TORCH_INTERNAL_ASSERT(grad_out.dim() > 1);  // group_norm can't operate on 1D tensors so the output will be at least 2D
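  // grad_bias: reduce grad_out over every dim except the channel dim. This is
  // done with regular tensor ops on the still-wrapped grad_out, so vmap
  // handles any batch dim automatically.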
  if (output_mask[2]) {
    grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim()));
  }

  if (output_mask[1] && weight.defined()) {
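    // grad_weight: reshape input to [N, G, C/G, *], normalize it per group as
    // (x - mean) * rstd, then reduce normalized_input * grad_out over every
    // dim except the channel dim. Again plain tensor ops, so vmap handles any
    // batch dims here.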
    const auto reshaped_input = reshape_dim_outof(1, group, input);
    const auto normalized_input = (reshaped_input - padRight(mean, nullopt, reshaped_input.dim())) * padRight(rstd, nullopt, reshaped_input.dim());
    const auto expanded_grad_weight = reshape_dim_into(1, 1, normalized_input) * grad_out;
    grad_weight = expanded_grad_weight.transpose(0, 1).sum(range(1, expanded_grad_weight.dim()));
  }

  if (output_mask[0]) {
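    // grad_input: fold the affine weight into grad_out up front, so the
    // recursive native_group_norm_backward call below does not need a weight
    // argument.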
    const auto grad_normalized_input = weight.defined() ?
      grad_out * padRight(weight, nullopt, grad_out.dim() - 1) : grad_out;
    Tensor grad_normalized_input_value;
    optional<int64_t> grad_normalized_input_bdim;
    std::tie(grad_normalized_input_value, grad_normalized_input_bdim) =
        unwrapTensorAtLevel(grad_normalized_input, cur_level);
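    // Work on the unwrapped values from here on: move each batch dim to the
    // front and materialize it on operands that lack one, so the vmap dim can
    // be folded into the group norm batch dim N below.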
    auto grad_out_ = moveBatchDimToFront(grad_normalized_input_value, grad_normalized_input_bdim);
    auto input_ = moveBatchDimToFront(input_value, input_bdim);
    auto mean_ = moveBatchDimToFront(mean_value, mean_bdim);
    auto rstd_ = moveBatchDimToFront(rstd_value, rstd_bdim);

    const auto bdim_size = get_bdim_size3(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        weight_value, weight_bdim);
    grad_out_ = ensure_has_bdim(grad_out_, grad_normalized_input_bdim.has_value(), bdim_size);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
    mean_ = ensure_has_bdim(mean_, mean_bdim.has_value(), bdim_size);
    rstd_ = ensure_has_bdim(rstd_, rstd_bdim.has_value(), bdim_size);

    grad_out_ = reshape_dim_into(0, 0, grad_out_); // [B0 * N, C, *]
    input_ = reshape_dim_into(0, 0, input_);       // [B0 * N, C, *]
    mean_ = reshape_dim_into(0, 0, mean_);         // [B0 * N, G]
    rstd_ = reshape_dim_into(0, 0, rstd_);         // [B0 * N, G]

    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
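    // The guard above excludes the Batched dispatch key, so this call reaches
    // the regular native_group_norm_backward kernel instead of re-entering
    // this batch rule.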
    const auto result = native_group_norm_backward(
        grad_out_,
        input_,
        mean_,
        rstd_,
        nullopt, N * bdim_size, C, HxW, group, {true, false, false});
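    // Only grad_input is requested from the recursive call ({true, false, false});
    // grad_weight and grad_bias were already handled above.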
    auto result0 = std::get<0>(result);
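    // Split the flattened [B0 * N, ...] leading dim back out, then re-wrap
    // grad_input as a batched tensor (batch dim 0) at the current level.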
    result0 = reshape_dim_outof(0, bdim_size, result0);
    grad_input = makeBatched(result0, 0, cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
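
For reference, a minimal sketch of how a plumbing function like this is typically registered so that vmap dispatches to it. The registration block and the FT_BATCHED_KEY macro are assumptions based on the TORCH_LIBRARY_IMPL pattern used for other batch rules in functorch, not a copy of this file's actual registration:

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  // Hypothetical registration: route the batched native_group_norm_backward
  // op through the plumbing function above.
  m.impl("native_group_norm_backward", native_group_norm_backward_plumbing);
}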