std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing()

in functorch/csrc/BatchRulesConvolution.cpp [405:497]


std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const c10::optional<IntArrayRef> bias_sizes_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
    IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
  const auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
  Tensor grad_output;
  optional<int64_t> grad_output_bdim;
  std::tie(grad_output, grad_output_bdim) = unwrapTensorAtLevel(grad_output_, cur_level);
  Tensor input;
  optional<int64_t> input_bdim;
  std::tie(input, input_bdim) = unwrapTensorAtLevel(input_, cur_level);
  Tensor weight;
  optional<int64_t> weight_bdim;
  std::tie(weight, weight_bdim) = unwrapTensorAtLevel(weight_, cur_level);

  const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
  output_mask[2] = false;
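  // compute_grad_bias runs on the still-batched grad_output_ because the bias
  // gradient is just a reduction over the batch and spatial dims, which vmap
  // handles natively; clearing output_mask[2] makes the convolution_backward
  // calls below skip the bias computation entirely.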

  // TODO: A little bird says that unfold + matmul is actually faster than
  // group convolution in many cases. We should benchmark some of
  // the common cases and replace things with unfold + matmul as necessary.
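  // For reference, a rough sketch of the unfold + matmul formulation for a
  // plain (non-grouped) 2D forward conv -- hypothetical, not used here; N/O/I
  // follow the notation below and kH/kW/H_out/W_out stand for the kernel and
  // output spatial sizes:
  //   auto cols = at::im2col(input, {kH, kW}, dilation, padding, stride); // [N, I*kH*kW, L]
  //   auto out  = weight.view({O, -1}).matmul(cols)                       // [N, O, L]
  //                  .view({N, O, H_out, W_out});
  // The backward passes would then be the corresponding matmuls against cols
  // and grad_output, followed by at::col2im for grad_input.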

  // Notation:
  // B - a batch dimension
  // G - groups (sometimes omitted because it doesn't matter)
  // NO - grad_output
  // NI - input
  // OI - weight
  // "(BO)I" - we don't actually care about the values of this Tensor,
  //           we just need to create a tensor on the same device with the
  //           correct shape and pray that the implementation is smart enough
  //           to not do anything with it.

  // BNO, BNI, BOI
  // AKA one of the model ensembling cases
  if (grad_output_bdim && input_bdim && weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);

    // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
    const auto batch_size = weight.size(*weight_bdim);
    input = reshape_dim_into(*input_bdim, 1, input);
    weight = reshape_dim_into(*weight_bdim, 0, weight);
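    // With the batch dim folded into the channel dims of all three tensors,
    // each ensemble member occupies its own block of channels, so a single
    // grouped call with batch_size * groups groups computes every per-member
    // gradient at once.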
    const auto result = at::convolution_backward(
        grad_output, input, weight, nullopt, stride, padding, dilation,
        transposed, output_padding, batch_size * groups, output_mask);
    // N(BI), (BO)I -> NBI, BOI
    const auto grad_input = output_mask[0] ?
      reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
    const auto grad_weight = output_mask[1] ?
      reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
    return std::make_tuple(
        output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
        output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
        grad_bias);
  }

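  // The *_batch_rule helpers below take each (tensor, optional bdim) pair and
  // return a (result, result_bdim) pair, which is then re-wrapped as a
  // batched tensor at cur_level via makeBatched.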
  Tensor grad_input;
  if (output_mask[0]) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto result = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }

  Tensor grad_weight;
  if (output_mask[1]) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto result = convolution_backward_weight_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }
  return std::make_tuple(grad_input, grad_weight, grad_bias);

  // Someone's definitely going to find a problem with this batching rule, so
  // I'm leaving the following fallback here in case we need to bring it back.
  // static auto op = c10::Dispatcher::singleton()
  //   .findSchemaOrThrow("aten::convolution_backward", "");
  // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
  //   grad_output_, input_, weight_, bias_sizes_opt,
  //   stride, padding, dilation, transposed, output_padding, groups, output_mask
  // });
  // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
}
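
For context, a plumbing function like this only fires when aten::convolution_backward is called under vmap, i.e. when it is registered for functorch's batched dispatch key. A minimal registration sketch, assuming the usual TORCH_LIBRARY_IMPL block (the exact key macro used in this file may differ):

TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
  m.impl("convolution_backward", convolution_backward_plumbing);
}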