in functorch/csrc/BatchRulesReduceOps.cpp [68:158]
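// Generic boxed batching rule for reduction ops: unwrap the batched `self`
// argument, remap the reduction dim(s) past the batch dim, redispatch to the
// underlying op, and wrap the tensor outputs back up as batched tensors.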
void boxed_reduction_batch_rule(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();
  auto arguments = torch::jit::pop(*stack, num_arguments);
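  // Exclude the batched dispatch key so the redispatch below skips this
  // batching rule, and look up the current functorch dynamic layer (vmap level).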
  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();

  std::vector<std::pair<Tensor, optional<int64_t>>> tensor_inputs;
  std::vector<int64_t> tensor_pos;
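  // Unwrap the first argument (self) at the current level and move its batch
  // dim (if any) to the front.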
  TORCH_INTERNAL_ASSERT(arguments[0].isTensor());
  Tensor self;
  optional<int64_t> self_bdim;
  std::tie(self, self_bdim) = unwrapTensorAtLevel(arguments[0].toTensor(), cur_level);

  self = moveBatchDimToFront(self, self_bdim);

  auto logical_dim = rankWithoutBatchDim(self, self_bdim);
  std::vector<int64_t> dims;
  ReductionCase reduction_case;
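  // Normalize the dim argument into a list of logical dims to reduce over.
  // It may be an int, an int[], or None; None (or an empty int[]) means a
  // full reduction, which for the int? case is handled by flattening self.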
  if (arguments[dim_arg_pos].isIntList()) {
    reduction_case = ReductionCase::DimArray;
    dims = arguments[dim_arg_pos].toIntList().vec();
    if (dims.size() == 0) {
      auto all_dims = range(0, std::max((int64_t)1, logical_dim));
      dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
    }
  } else if (arguments[dim_arg_pos].isInt()) {
    reduction_case = ReductionCase::Dim;
    dims = {arguments[dim_arg_pos].toInt()};
  } else if (arguments[dim_arg_pos].isNone()) {
    auto param_type = schema.arguments()[dim_arg_pos].type()->expect<OptionalType>()->getElementType();
    if (param_type->kind() == IntType::Kind) {
      reduction_case = ReductionCase::Dim;
      if (self.dim() > 1) {
        self = self.flatten(1);
      }
      dims = {0};
    } else if (param_type->kind() == ListType::Kind) {
      reduction_case = ReductionCase::DimArray;
      if (logical_dim == 0) {
        dims = {0};
      } else {
        auto all_dims = range(0, self.dim() - 1);
        dims = std::vector<int64_t>(all_dims.begin(), all_dims.end());
      }
    } else {
      TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
    }
  } else {
    TORCH_INTERNAL_ASSERT(false, "Unexpected dtype found at dims");
  }
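  // Convert each logical reduction dim to a physical dim in self, which now
  // has the batch dim (if present) at the front.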
  VmapDimVector new_dims;
  new_dims.reserve(dims.size());
  for (auto dim : dims) {
    new_dims.push_back(getPhysicalDim(self, self_bdim.has_value(), dim));
  }
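  // Reducing a logically 0-dim tensor: unsqueeze a size-1 dim to reduce over,
  // then squeeze it back out of each result below.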
  bool is_scalar_case = logical_dim == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0]);
  if (is_scalar_case) {
    self = self.unsqueeze(-1);
    new_dims = {1};
  }
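  // Write the adjusted self and dim argument(s) back into the argument list.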
  arguments[0] = self;
  if (reduction_case == ReductionCase::DimArray) {
    arguments[dim_arg_pos] = std::vector<int64_t>(new_dims.begin(), new_dims.end());
  } else if (reduction_case == ReductionCase::Dim) {
    arguments[dim_arg_pos] = new_dims[0];
  }
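  // Push the arguments back onto the stack and call the underlying op; the
  // dispatch key guard above keeps this call from re-entering the batch rule.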
  for (const auto arg_idx : c10::irange(0, num_arguments)) {
    torch::jit::push(stack, arguments[arg_idx]);
  }
  op.callBoxed(stack);
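  // Re-wrap each tensor output as a BatchedTensor with its batch dim at
  // position 0, undoing the scalar-case unsqueeze first.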
  const auto returns = torch::jit::pop(*stack, num_returns);
  for (const auto& ret : returns) {
    if (ret.isTensor()) {
      auto res = ret.toTensor();
      if (is_scalar_case) {
        res = res.squeeze(-1);
      }
      torch::jit::push(stack, makeBatched(res, 0, cur_level));
    } else {
      TORCH_INTERNAL_ASSERT(false, "This boxed batching rule does not currently support ops that return non-tensor values");
    }
  }
}