std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing()

in functorch/csrc/BatchRulesNorm.cpp [198:292]


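// Plumbing for the batch_norm backward batch rule: unwraps the BatchedTensor
// arguments at the current vmap level, computes grad_bias and grad_weight
// directly with composite tensor ops, and delegates grad_input to
// batch_norm_backward_no_weight_bias_batch_rule.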
template <typename F, F Func>
std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
    const at::Tensor & grad_out,
    const at::Tensor & input,
    const c10::optional<at::Tensor> & weight_opt,
    const c10::optional<at::Tensor> & running_mean_opt,
    const c10::optional<at::Tensor> & running_var_opt,
    const c10::optional<at::Tensor> & save_mean_opt,
    const c10::optional<at::Tensor> & save_rstd_opt,
    bool training,
    double eps,
    std::array<bool,3> output_mask) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight_opt);
  const Tensor& weight = *weight_maybe_owned;
  c10::MaybeOwned<Tensor> running_mean_maybe_owned = at::borrow_from_optional_tensor(running_mean_opt);
  const Tensor& running_mean = *running_mean_maybe_owned;
  c10::MaybeOwned<Tensor> running_var_maybe_owned = at::borrow_from_optional_tensor(running_var_opt);
  const Tensor& running_var = *running_var_maybe_owned;
  // NB: not sure why these are optional; they are required outputs of the forward pass
  const Tensor& save_mean = *save_mean_opt;
  const Tensor& save_rstd = *save_rstd_opt;
  TORCH_INTERNAL_ASSERT(save_mean.defined());
  TORCH_INTERNAL_ASSERT(save_rstd.defined());

  // plumbing
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
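  // Unwrap each argument into its underlying tensor plus an optional batch
  // dim at this vmap level; undefined optional tensors stay nullopt.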
  Tensor grad_out_value;
  optional<int64_t> grad_out_bdim;
  std::tie(grad_out_value, grad_out_bdim) = unwrapTensorAtLevel(grad_out, cur_level);
  Tensor input_value;
  optional<int64_t> input_bdim;
  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
  optional<Tensor> weight_value;
  optional<int64_t> weight_bdim;
  if (weight.defined()) {
    std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight, cur_level);
  }
  optional<Tensor> running_mean_value;
  optional<int64_t> running_mean_bdim;
  if (running_mean.defined()) {
    std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean, cur_level);
  }
  optional<Tensor> running_var_value;
  optional<int64_t> running_var_bdim;
  if (running_var.defined()) {
    std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var, cur_level);
  }
  Tensor save_mean_value;
  optional<int64_t> save_mean_bdim;
  std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean, cur_level);
  Tensor save_rstd_value;
  optional<int64_t> save_rstd_bdim;
  std::tie(save_rstd_value, save_rstd_bdim) = unwrapTensorAtLevel(save_rstd, cur_level);

  // results
  Tensor grad_bias;
  Tensor grad_weight;
  Tensor grad_input;

  TORCH_INTERNAL_ASSERT(grad_out_value.dim() > 1);  // batch_norm can't operate on 1D tensors so the output will be at least 2D
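  // grad_bias: reduce grad_out over every dimension except channel (dim 1).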
  if (output_mask[2]) {
    grad_bias = grad_out.transpose(0, 1).sum(range(1, grad_out.dim()));
  }
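  // grad_weight: d(output)/d(weight) is the normalized input, so reduce
  // normalized_input * grad_out over every dimension except channel.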
  if (output_mask[1] && weight_value.has_value()) {
    // NB: output isn't saved...
    auto mean = training ? save_mean : running_mean;
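    // NB: despite the name, `var` holds the reciprocal standard deviation in
    // both branches (save_rstd when training, 1/sqrt(running_var + eps) otherwise).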
    auto var = training ? save_rstd : (1 / at::sqrt(running_var + eps));
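    // padRight appends trailing size-1 dims so the per-channel stats
    // broadcast against the channel-first (transposed) input.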
    const auto normalized_input = (input.transpose(0, 1) - padRight(mean, nullopt, input.dim())) * padRight(var, nullopt, input.dim());
    const auto expanded_grad_weight = normalized_input * grad_out.transpose(0, 1);
    grad_weight = expanded_grad_weight.sum(range(1, grad_out.dim()));
  }
  if (output_mask[0]) {
    const auto grad_normalized_input = weight.defined() ?
      grad_out.transpose(0, 1) * padRight(weight, nullopt, grad_out.dim()) : grad_out.transpose(0, 1);           // [B0, C, B, *]
    Tensor grad_normalized_input_value;
    optional<int64_t> grad_normalized_input_bdim;
    std::tie(grad_normalized_input_value, grad_normalized_input_bdim) =
        unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level);       // [B0, B, C, *]

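    // Exclude the batching key so the batch rule below reaches the unbatched
    // kernels; the rule handles the explicit batch dims itself.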
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        running_mean_value, running_mean_bdim,
        running_var_value, running_var_bdim,
        save_mean_value, save_mean_bdim,
        save_rstd_value, save_rstd_bdim,
        training, eps);
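    // Rewrap the per-example result as a BatchedTensor at the current level.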
    grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);
}
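
For context, plumbing functions like this are registered as kernel overrides for functorch's batched dispatch key. A minimal sketch of what such a registration could look like, assuming the FT_BATCHED_KEY macro and the ATEN_FN operator helper; the exact registration in BatchRulesNorm.cpp may differ:

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  // Route native_batch_norm_backward through the plumbing above whenever a
  // BatchedTensor is involved; Func is the unbatched op it delegates to.
  m.impl("native_batch_norm_backward",
         batch_norm_backward_plumbing<
             decltype(&ATEN_FN(native_batch_norm_backward)),
             &ATEN_FN(native_batch_norm_backward)>);
}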