in functorch/csrc/BatchedFallback.cpp [280:418]
void batchedTensorForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
const auto& schema = op.schema();
const auto num_returns = schema.returns().size();
const auto num_arguments = schema.arguments().size();
const auto arguments = torch::jit::last(stack, num_arguments);
TORCH_CHECK(areAllReturnsTensors(schema) && !areAnyArgumentsTensorList(schema),
"Batching rule not implemented for ", schema.operator_name(), ". ",
"We could not generate a fallback.");
if (std::none_of(arguments.begin(), arguments.end(), ivalueParticipatesInCurrentLevel)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
op.callBoxed(stack);
return;
}
if (isInplaceOp(schema)) {
batchedTensorInplaceForLoopFallback(op, stack);
return;
}
TORCH_CHECK(!schema.is_mutable() && !schema.hasAnyAliasInfo(),
"Batching rule not implemented for ", schema.operator_name(), "; ",
"the fallback path doesn't work on out= or view ops.");
TORCH_CHECK(num_returns >= 1,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support operations with no returns.");
warnFallback(schema, /*in_place*/false);
const auto arguments_begin = stack->size() - num_arguments;
// Figure out which arguments are BatchedTensors and save them to a vector.
// For each BatchedTensor, also record which position in `arguments` it came from.
at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
VmapDimVector batched_tensor_inputs_position;
for (const auto idx : c10::irange(0, arguments.size())) {
const auto& ivalue = arguments[idx];
if (!ivalue.isTensor()) {
continue;
}
const auto& tensor = ivalue.toTensor();
if (!tensor.defined()) {
continue;
}
const auto* batched = maybeGetBatchedImpl(tensor);
if (!batched) {
continue;
}
batched_tensor_inputs.push_back(tensor);
batched_tensor_inputs_position.push_back(idx);
}
TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);
// MultiBatchVmapTransform the BatchedTensor arguments. This returns
// VmapPhysicalViews that contain all of the batch dimensions.
const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
batched_tensor_inputs);
// Compute the total number of batches
auto num_batch_dims = input_physical_views.front().numBatchDims();
auto some_sizes = input_physical_views.front().tensor().sizes();
auto batch_sizes = ArrayRef<int64_t>(some_sizes.begin(), some_sizes.begin() + num_batch_dims);
const auto num_batches = c10::multiply_integers(batch_sizes);
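// e.g. with nested vmaps of sizes 2 and 3, batch_sizes is [2, 3] and
// num_batches is 6: the loop below runs the op once per example.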
// If any batch dim has size 0, there are no examples to run. Without a
// shape-checking API we are unable to compute the correct shape of the
// output, so we just error out.
TORCH_CHECK(num_batches > 0,
"Batching rule not implemented for ", schema.operator_name(), ". ",
"The fallback path does not support vmap over dims of size 0.");
// Strategy: For each batch, we are going to push slices (where applicable)
// of the arguments onto `stack`, call `op`, and store the result in
// `output_shards`.
//
// NOTE: [Output shards layout]
// Assume that the operator has three outputs: a, b, c.
// The layout of output_shards is as follows:
// [ a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3]
// This is so that we can call at::stack([a0...a3]), at::stack([b0...b3])
// more easily in the next step.
std::vector<Tensor> output_shards(num_batches * num_returns);
for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
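// Convert the flat batch index into a multi-dimensional index over
// batch_sizes, e.g. linear_idx 4 with batch_sizes [2, 3] -> (1, 1).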
auto index = computeIndex(linear_idx, batch_sizes);
auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
auto input_physical_views_iter = input_physical_views.begin();
for (const auto arg_idx : c10::irange(0, num_arguments)) {
// We assume that torch::jit::Stack is backed by vector<IValue> for
// simplicity. When that is not the case, this code should be updated.
const auto& argument = (*stack)[arguments_begin + arg_idx];
if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
|| (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
// argument isn't a BatchedTensor
torch::jit::push(stack, argument);
continue;
}
// argument is a BatchedTensor
TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
const auto& physical_view_for_argument = *input_physical_views_iter;
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
torch::jit::push(stack, physical_view_for_argument.tensor().index(index));
batched_tensor_inputs_pos_iter++;
input_physical_views_iter++;
}
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
op.callBoxed(stack);
// Store the result into `output_shards`. See NOTE: [Output shards layout]
// to learn about the details of how we store the shards.
const auto returns = torch::jit::last(stack, num_returns);
for (const auto return_idx : c10::irange(0, returns.size())) {
output_shards[num_batches * return_idx + linear_idx] = returns[return_idx].toTensor();
}
torch::jit::drop(stack, num_returns);
}
// For each output Tensor, stack its shards together to form a single return.
torch::jit::drop(stack, num_arguments);
auto output_shards_chunks = MatrixRef<Tensor>(output_shards, num_batches);
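// Row `return_idx` of output_shards_chunks holds the num_batches shards
// for that return; see NOTE: [Output shards layout].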
for (const auto return_idx : c10::irange(0, num_returns)) {
auto shards = output_shards_chunks[return_idx];
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
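// safeStack stacks the shards along a new leading dim; it returns an
// undefined tensor only when every shard is undefined.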
auto flat_output = safeStack(shards);
// See NOTE [vmap through backward and undefined grad]
if (!flat_output.defined()) {
torch::jit::push(stack, flat_output);
continue;
}
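// Un-flatten the stacked dim: replace the leading dim of size num_batches
// with the original batch_sizes, keeping the remaining per-example sizes.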
VmapDimVector output_sizes(batch_sizes);
output_sizes.insert(
output_sizes.end(),
flat_output.sizes().begin() + 1,
flat_output.sizes().end());
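// Wrap the physical result back into a BatchedTensor at the current
// logical level before returning it on the stack.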
torch::jit::push(
stack,
input_physical_views.front().getPhysicalToLogicalMap().apply(flat_output.view(output_sizes)));
}
}