in functorch/csrc/BatchRulesHelper.h [110:157]
// Generic boxed batching rule for ops whose batching logic only needs to
// adjust the op's Tensor arguments.
//
// Steps:
//   1. Pop all of the op's schema arguments off `stack`.
//   2. Unwrap every Tensor argument at the current dynamic layer's level into
//      a (tensor, optional batch dim) pair and pass the whole list to `Func`,
//      which may modify the pairs in place.
//   3. Push the arguments back in their original order, substituting the
//      (possibly modified) tensors, and redispatch `op` via callBoxed while
//      kBatchedKey is excluded.
//   4. Wrap every returned Tensor with makeBatched(ret, 0, cur_level).
//
// NOTE(review): because every output is wrapped with batch dim 0, `Func` is
// presumably responsible for arranging the inputs so that the op's outputs
// carry the batch dimension at position 0 -- confirm against the Func
// implementations that instantiate this template.
//
// Internal-asserts if there is no current dynamic layer, if the op has no
// Tensor arguments, if `Func` changes the number of tensor inputs, or if the
// op returns a non-Tensor value.
void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);

  // kBatchedKey stays excluded for the rest of this scope, including the
  // redispatch below (presumably so the call does not re-enter batching).
  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  const int64_t cur_level = maybe_layer->layerId();

  // Unwrap the Tensor arguments, remembering the argument slot of each one.
  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
  std::vector<size_t> tensor_pos;
  for (const auto idx : c10::irange(0, num_arguments)) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    Tensor tensor_value;
    optional<int64_t> tensor_bdim;
    std::tie(tensor_value, tensor_bdim) = unwrapTensorAtLevel(ivalue.toTensor(), cur_level);
    tensor_inputs.emplace_back(std::move(tensor_value), tensor_bdim);
    tensor_pos.push_back(idx);
  }
  // Check before invoking Func so it is never handed an empty input list.
  TORCH_INTERNAL_ASSERT(
      !tensor_pos.empty(),
      "boxed_tensor_inputs_batch_rule requires at least one Tensor argument");

  Func(tensor_inputs);
  // Func must keep a one-to-one correspondence with the original tensor slots.
  TORCH_INTERNAL_ASSERT(
      tensor_inputs.size() == tensor_pos.size(),
      "boxed_tensor_inputs_batch_rule: Func must not change the number of tensor inputs");

  // Put the processed tensors back into their original argument slots, then
  // restore the full argument list onto the stack in order.
  for (const auto i : c10::irange(tensor_pos.size())) {
    arguments[tensor_pos[i]] = std::move(tensor_inputs[i].first);
  }
  for (auto& arg : arguments) {
    torch::jit::push(stack, std::move(arg));
  }

  op.callBoxed(stack);

  // Re-wrap every output as batched at dim 0 on the current level.
  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      torch::jit::push(stack, makeBatched(ret.toTensor(), 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}