in functorch/csrc/BatchRulesConvolution.cpp [405:497]
std::tuple<Tensor,Tensor,Tensor> convolution_backward_plumbing(
    const Tensor& grad_output_, const Tensor& input_, const Tensor& weight_,
    const c10::optional<IntArrayRef> bias_sizes_opt,
    IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, bool transposed,
    IntArrayRef output_padding, int64_t groups, std::array<bool, 3> output_mask) {
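  // Plumbing for convolution_backward under vmap: unwrap the BatchedTensor
  // arguments, dispatch to a suitable batching strategy, and re-wrap the
  // results at the current dynamic-layer level.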
  const auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
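  // Unwrap each argument into (tensor, optional batch dim) at the current vmap
  // level; a nullopt bdim means that argument is not being vmapped over.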
  Tensor grad_output;
  optional<int64_t> grad_output_bdim;
  std::tie(grad_output, grad_output_bdim) = unwrapTensorAtLevel(grad_output_, cur_level);
  Tensor input;
  optional<int64_t> input_bdim;
  std::tie(input, input_bdim) = unwrapTensorAtLevel(input_, cur_level);
  Tensor weight;
  optional<int64_t> weight_bdim;
  std::tie(weight, weight_bdim) = unwrapTensorAtLevel(weight_, cur_level);
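  // grad_bias depends only on grad_output, so compute it up front on the
  // still-wrapped tensor and drop the bias entry from the mask passed to the
  // kernels below.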
  const auto grad_bias = compute_grad_bias(grad_output_, output_mask);
  output_mask[2] = false;
  // TODO: A little bird says that unfold + matmul is actually faster than
  // group convolution in many cases. We should benchmark some of the common
  // cases and replace things with unfold + matmul as necessary.
  //
  // Notation:
  // B - a batch dimension
  // G - groups (sometimes omitted because it doesn't matter)
  // NO - grad_output
  // NI - input
  // OI - weight
  // "(BO)I" - we don't actually care about the values of this Tensor,
  //           we just need to create a tensor on the same device with the
  //           correct shape and pray that the implementation is smart enough
  //           to not do anything with it.
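  //
  // Illustrative example (shapes made up for exposition): vmapping over B = 4
  // conv2d models, each with its own minibatch of N = 8, C_in = 3, C_out = 16:
  //   grad_output: [B, N, C_out, H_out, W_out]  (BNO)
  //   input:       [B, N, C_in,  H_in,  W_in ]  (BNI)
  //   weight:      [B, C_out, C_in, kH, kW]     (BOI)
  // The strategies below fold B into existing dimensions so that a plain
  // convolution_backward call (possibly with extra groups) can do the work.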
  // BNO, BNI, BOI
  // AKA one of the model ensembling cases
  if (grad_output_bdim && input_bdim && weight_bdim) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    // BNO, BNI, BOI -> N(BO), N(BI), (BO)I
    grad_output = reshape_dim_into(*grad_output_bdim, 1, grad_output);
    const auto batch_size = weight.size(*weight_bdim);
    input = reshape_dim_into(*input_bdim, 1, input);
    weight = reshape_dim_into(*weight_bdim, 0, weight);
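    // With B folded into the channel dimensions, a single grouped
    // convolution_backward call (batch_size * groups groups) computes all B
    // per-model gradients at once.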
    const auto result = at::convolution_backward(
        grad_output, input, weight, nullopt, stride, padding, dilation,
        transposed, output_padding, batch_size * groups, output_mask);

    // N(BI), (BO)I -> NBI, BOI
    const auto grad_input = output_mask[0] ?
      reshape_dim_outof(1, batch_size, std::get<0>(result)) : Tensor();
    const auto grad_weight = output_mask[1] ?
      reshape_dim_outof(0, batch_size, std::get<1>(result)) : Tensor();
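    // Re-wrap the requested gradients as BatchedTensors at this level, with
    // the batch dim where reshape_dim_outof left it (dim 1 for grad_input,
    // dim 0 for grad_weight).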
    return std::make_tuple(
        output_mask[0] ? makeBatched(grad_input, 1, cur_level) : grad_input,
        output_mask[1] ? makeBatched(grad_weight, 0, cur_level) : grad_weight,
        grad_bias);
  }
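  // General case: compute grad_input and grad_weight separately via their
  // dedicated batch rules, passing along whichever of grad_output, input, and
  // weight are actually batched.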
  Tensor grad_input;
  if (output_mask[0]) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto result = convolution_backward_input_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_input = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }

  Tensor grad_weight;
  if (output_mask[1]) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto result = convolution_backward_weight_batch_rule(
        grad_output, grad_output_bdim,
        input, input_bdim,
        weight, weight_bdim,
        stride, padding, dilation, transposed, output_padding, groups);
    grad_weight = makeBatched(std::get<0>(result), std::get<1>(result), cur_level);
  }

  return std::make_tuple(grad_input, grad_weight, grad_bias);
  // Someone's definitely going to find a problem with this batching rule, so
  // I'm leaving the following fallback here in case we need to bring it back.
  // static auto op = c10::Dispatcher::singleton()
  //   .findSchemaOrThrow("aten::convolution_backward", "");
  // auto result = slow_fallback<Tensor,Tensor,Tensor>(op, {
  //     grad_output_, input_, weight_, bias_sizes_opt,
  //     stride, padding, dilation, transposed, output_padding, groups, output_mask
  // });
  // return std::make_tuple(grad_input, std::get<1>(result), grad_bias);
}