inline void boxed_all_tensors_have_optional_bdim(op, stack)

in functorch/csrc/BatchRulesHelper.h, lines 269–332


// Boxed batching rule for ops whose tensor arguments all share one layout:
// either every tensor has logical (per-example) rank `feature_rank` (the
// "no batch dim" case) or every tensor has logical rank `feature_rank + 1`.
// The layout is decided by the first tensor argument and asserted for the
// rest.
//
// - No-batch-dim case: each tensor's vmap dim is moved to dim 0
//   (moveBatchDimToFront). Each return is re-wrapped as batched at dim 0.
// - Otherwise: each tensor's vmap dim is folded into its existing dim 0
//   (reshape_dim_into). After the call, dim 0 of each return is split back
//   out with reshape_dim_outof(0, batch_size, ...) and then re-wrapped as
//   batched at dim 0.
//
// Tensors with no batch dim at the current level are passed through
// ensure_has_bdim(..., batch_size), and their bdim is then taken to be 0.
//
// NOTE(review): `feature_rank` and `contig_tensor_index` are not declared in
// this block. Presumably they are template parameters on the preceding line;
// confirm. `contig_tensor_index` is compared against the index among the
// *tensor* arguments (the position in `tensor_inputs`), not the schema
// argument position. The matching tensor is made contiguous before the call.
//
// Only ops whose returns are all tensors are supported (asserted).
inline void boxed_all_tensors_have_optional_bdim(
    const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  const auto num_returns = schema.returns().size();
  const auto num_arguments = schema.arguments().size();

  // Exclude the batched key while this rule runs, presumably so that
  // op.callBoxed below does not dispatch back into this rule.
  c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  const int64_t cur_level = maybe_layer->layerId();

  const int64_t args_begin = stack->size() - num_arguments;
  SmallVector<UnpackedBatchedTensor, 5> tensor_inputs;
  SmallVector<int64_t, 5> tensor_pos;
  // Filled in by find_and_unpack_tensors. It is initialized here so it is
  // never read indeterminate.
  int64_t batch_size = -1;

  find_and_unpack_tensors(
      stack, num_arguments, cur_level,
      &tensor_inputs, &tensor_pos, &batch_size);

  // Set from the first tensor argument. Every later tensor must agree.
  optional<bool> is_no_batch_dim_case;

  for (const auto tensor_idx : c10::irange(0, tensor_inputs.size())) {
    const auto& value = std::get<0>(tensor_inputs[tensor_idx]);
    auto bdim = std::get<1>(tensor_inputs[tensor_idx]);
    const auto logical_rank = rankWithoutBatchDim(value, bdim);

    if (!is_no_batch_dim_case.has_value()) {
      is_no_batch_dim_case = (logical_rank == feature_rank);
    }
    auto value_ = ensure_has_bdim(value, bdim.has_value(), batch_size);
    if (!bdim.has_value()) {
      bdim = 0;
    }
    if (*is_no_batch_dim_case) {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank);
      value_ = moveBatchDimToFront(value_, bdim);
    } else {
      TORCH_INTERNAL_ASSERT(logical_rank == feature_rank + 1);
      value_ = reshape_dim_into(*bdim, 0, value_);
    }
    // Compare as signed. A negative contig_tensor_index (e.g. -1) must never
    // match. A size_t/int64_t comparison would instead convert it to a huge
    // unsigned value.
    if (static_cast<int64_t>(tensor_idx) == contig_tensor_index) {
      value_ = value_.contiguous();
    }
    // Move rather than copy, to avoid an extra refcount inc/dec per tensor.
    (*stack)[args_begin + tensor_pos[tensor_idx]] = std::move(value_);
  }

  op.callBoxed(stack);

  for (const auto idx : c10::irange(args_begin, args_begin + num_returns)) {
    const auto& ret = (*stack)[idx];
    TORCH_INTERNAL_ASSERT(ret.isTensor(),
        "This boxed batching rule does not currently support ops that return non-tensor values");
    // A layout decision exists only if at least one tensor argument was
    // seen. Dereferencing an empty optional would be undefined behavior.
    TORCH_INTERNAL_ASSERT(is_no_batch_dim_case.has_value());
    if (*is_no_batch_dim_case) {
      (*stack)[idx] = makeBatched(ret.toTensor(), 0, cur_level);
    } else {
      (*stack)[idx] = makeBatched(reshape_dim_outof(0, batch_size, ret.toTensor()), 0, cur_level);
    }
  }
}