in functorch/csrc/BatchRulesHelper.h [269:332]
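// Boxed batching rule for operators whose tensor arguments all accept an
// optional examplewise batch dimension (common for NN ops that take either
// [C, ...] or [N, C, ...] inputs). When no input carries that batch dim, the
// vmap dim is moved to the front and reused as it; otherwise the vmap dim is
// flattened into the existing batch dim and split back out of every output.
// feature_rank: logical rank of a tensor input that has no examplewise batch
//   dim. contig_tensor_index: index of a tensor input to force contiguous
//   before the call (negative means none).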
template <int64_t feature_rank, int64_t contig_tensor_index = -1>
inline void boxed_all_tensors_have_optional_bdim(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
  int64_t args_begin = stack->size() - num_arguments;

  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  int64_t batch_size;
  find_and_unpack_tensors(
      stack, num_arguments, cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);
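  // The first tensor input determines the case: logical rank == feature_rank
  // means the inputs carry no examplewise batch dim. The assertions in the
  // loop below check that every other tensor input agrees.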
  optional<bool> is_no_batch_dim_case;
  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    const auto logical_rank = rankWithoutBatchDim(value, bdim);

    if (!is_no_batch_dim_case.has_value()) {
      is_no_batch_dim_case = (logical_rank == feature_rank);
    }
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
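    // No examplewise batch dim: the vmap dim (now at the front) plays the role
    // of the op's optional batch dim, so the outputs need no reshaping.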
    if (*is_no_batch_dim_case) {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
      value_ = moveBatchDimToFront(value_, bdim);
      if (tensor_idx == contig_tensor_index) {
        value_ = value_.contiguous();
      }
      (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
      continue;
    }
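    // Examplewise batch dim present: fold the vmap dim into it, so the op sees
    // a single flat batch of size batch_size * N.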
    TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
    value_ = reshape_dim_into(*bdim, 0, value_);
    if (tensor_idx == contig_tensor_index) {
      value_ = value_.contiguous();
    }
    (*stack)[args_begin + tensor_pos[tensor_idx]] = value_;
  }
  op.callBoxed(stack);
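  // callBoxed pops the arguments and pushes num_returns results, so the
  // returned tensors now start at args_begin. Re-wrap each one at the current
  // level, splitting the flattened vmap dim back out in the batched case.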
  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    if (*is_no_batch_dim_case) {
      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
    } else {
      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
    }
  }
}
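
// A minimal registration sketch (illustrative, not part of this header):
// assuming the usual TORCH_LIBRARY_IMPL path for the FuncTorchBatched
// dispatch key, a rule could be hooked up roughly as below. The op name and
// the feature_rank value of 3 (for [C, H, W] inputs) are assumptions for the
// example only.
//
//   TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
//     m.impl("adaptive_avg_pool2d",
//            torch::CppFunction::makeFromBoxedFunction<
//                &boxed_all_tensors_have_optional_bdim</*feature_rank=*/3>>());
//   }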