void batchedTensorInplaceForLoopFallback()

in functorch/csrc/BatchedFallback.cpp [98:215]


void batchedTensorInplaceForLoopFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  const auto& schema = op.schema();
  warnFallback(schema, /*in_place*/true);

  const auto num_arguments = schema.arguments().size();
  const auto arguments = torch::jit::last(stack, num_arguments);
  const auto arguments_begin = stack->size() - num_arguments;

  // `self` is the Tensor being modified in-place
  Tensor self = arguments[0].toTensor();
  const auto* self_impl = maybeGetBatchedImpl(self);
  std::bitset<kVmapMaxTensorDims> self_vmap_levels;
  if (self_impl) {
    self_vmap_levels = createVmapLevelsBitset(self_impl->level());
  }

  // Figure out which arguments are BatchedTensor. Save them to a vector.
  // For each BatchedTensor, also record what position of `arguments` they came from.
  at::SmallVector<Tensor,kVmapTransformStaticInputSize> batched_tensor_inputs;
  VmapDimVector batched_tensor_inputs_position;
  for (const auto idx : c10::irange(0, arguments.size())) {
    const auto& ivalue = arguments[idx];
    if (!ivalue.isTensor()) {
      continue;
    }
    const auto& tensor = ivalue.toTensor();
    if (!tensor.defined()) {
      continue;
    }
    const auto* batched = maybeGetBatchedImpl(tensor);
    if (!batched) {
      continue;
    }

    // NOTE: [vmap-incompatible in-place operations]
    // In-place operations on `self` are not possible if there exists some vmap
    // level `l` such that `self` is not being vmapped on that level but another
    // argument is. For example, let B0 be a batch dim inside vmap and consider
    // vmap(Tensor.add_, in_dims=(None, 0))(torch.ones(3), torch.ones(B0, 3))
    // - self is torch.ones(3) and does not participate in this vmap
    // - other is BatchedTensor(torch.ones(B0, 3))
    // There's no way to do self.add_(other) because `other` has more elements
    // than `self` due to being vmapped over.
    //
    // In the vmap fallback, we should error out when we detect this.
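    // Worked example of the check below (bit values chosen purely for
    // exposition): suppose self_vmap_levels has only bit 1 set and
    // other_vmap_levels has only bit 2 set. Then
    //   self_vmap_levels | other_vmap_levels   has bits {1, 2} set,
    // which differs from self_vmap_levels, so the TORCH_CHECK fires;
    // additional_bdims ends up with only bit 2 set, and llvm::findLastSet
    // reports level 2 as the offending level.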
    auto other_vmap_levels = createVmapLevelsBitset(batched->level());
    if (self_vmap_levels != (self_vmap_levels | other_vmap_levels)) {
      // Find one vmap level to complain about
      auto additional_bdims = (self_vmap_levels | other_vmap_levels) ^ self_vmap_levels;
      auto offending_level = llvm::findLastSet(additional_bdims.to_ulong());
      // The following prints out "vmap: aten::add_(self, *extra_args) is not
      // possible", but it would be better to print out "tensor.add_(...) is
      // not possible". AFAICT there's no official way to get just `add_` from
      // the schema, and there is no way to tell if an operator has method or
      // function variants.
      TORCH_CHECK(false,
        "vmap: ", schema.name(), "(self, *extra_args) is not possible because ",
        "there exists a Tensor `other` in extra_args that has more elements ",
        "than `self`. This happened due to `other` being vmapped over but ",
        "`self` not being vmapped over at level ", offending_level, ". ",
        "Please try to use out-of-place operators instead of ", schema.name(), ". ",
        "If said operator is being called inside the PyTorch framework, ",
        "please file a bug report instead.");
    }
    batched_tensor_inputs.push_back(tensor);
    batched_tensor_inputs_position.push_back(idx);
  }
  TORCH_INTERNAL_ASSERT(batched_tensor_inputs.size() > 0);

  // MultiBatchVmapTransform the BatchedTensor arguments. This returns
  // VmapPhysicalViews that contain all of the batch dimensions.
  const auto input_physical_views = MultiBatchVmapTransform::logicalToPhysical(
      batched_tensor_inputs);
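  // For intuition (shapes assumed purely for exposition): if the batched
  // inputs are logically of shape [3] and are vmapped over batch sizes 5 and
  // 7, each returned VmapPhysicalView wraps a tensor of physical shape
  // [5, 7, 3], i.e. all batch dims are aligned at the front, and
  // numBatchDims() below would be 2.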

  // Compute the total number of batches
  auto num_batch_dims = input_physical_views.front().numBatchDims();
  auto first_physical_view_sizes = input_physical_views.front().tensor().sizes();
  auto batch_sizes = ArrayRef<int64_t>(
      first_physical_view_sizes.begin(), first_physical_view_sizes.begin() + num_batch_dims);
  const auto num_batches = c10::multiply_integers(batch_sizes);
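  // Continuing the illustrative shapes above: batch_sizes would be {5, 7},
  // so num_batches = 5 * 7 = 35 and the loop below invokes `op` 35 times,
  // once per combination of batch indices.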
  // Without a shape-checking API, we're unable to compute the correct shape of
  // the output when a batch dim has size 0, so we just error out.
  TORCH_CHECK(num_batches > 0,
      "Batching rule not implemented for ", schema.operator_name(), ". ",
      "The fallback path does not support vmap over dims of size 0.");

  // Strategy: For each batch, we are going to push slices (where applicable)
  // of the arguments onto `stack`, and call `op`.
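  // Stack bookkeeping for the loop: each iteration pushes one value per
  // argument on top of the original arguments (the argument itself if it is
  // not batched, or its per-batch slice if it is), callBoxed consumes those
  // pushed values and pushes one result, and drop(stack, 1) below removes
  // that result.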
  for (int64_t linear_idx = 0; linear_idx < num_batches; ++linear_idx) {
    auto index = computeIndex(linear_idx, batch_sizes);
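    // Example (assuming computeIndex enumerates batch indices in row-major
    // order, last dimension fastest): with batch_sizes = {5, 7}, linear_idx 9
    // maps to index = {1, 2}; tensor().index(index) below then selects the
    // per-example slice for that batch entry.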
    auto batched_tensor_inputs_pos_iter = batched_tensor_inputs_position.begin();
    auto input_physical_views_iter = input_physical_views.begin();
    for (const auto arg_idx : c10::irange(0, num_arguments)) {
      // We assume that torch::jit::Stack is backed by vector<IValue> for
      // simplicity. When that is not the case, this code should be updated.
      const auto& argument = (*stack)[arguments_begin + arg_idx];
      if (batched_tensor_inputs_pos_iter == batched_tensor_inputs_position.end()
          || (int64_t)arg_idx != *batched_tensor_inputs_pos_iter) {
        // argument isn't a BatchedTensor
        torch::jit::push(stack, argument);
        continue;
      }
      // argument is a BatchedTensor
      TORCH_INTERNAL_ASSERT(input_physical_views_iter != input_physical_views.end());
      const auto& physical_view_for_argument = *input_physical_views_iter;
      auto sliced = physical_view_for_argument.tensor().index(index);
      torch::jit::push(stack, sliced);
      batched_tensor_inputs_pos_iter++;
      input_physical_views_iter++;
    }

    op.callBoxed(stack);
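    // The in-place op leaves its single return value (typically an alias of
    // the slice pushed in the `self` position) on the stack; discard it, since
    // the writes have already landed in the physical tensor underlying `self`.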
    torch::jit::drop(stack, 1);
  }

  // Return the tensor that was written to in-place
  torch::jit::drop(stack, num_arguments);
  torch::jit::push(stack, self);
}