std::tuple<Tensor, Tensor> to_tensor_mask()

in nestedtensor/csrc/masking.cpp [250:310]

Converts a NestedTensor into a padded dense tensor plus a boolean mask marking the valid (non-padding) entries; when mask_dim is given, the mask is merged down to that many dimensions.

std::tuple<Tensor, Tensor> to_tensor_mask(
    Tensor nt,
    c10::optional<int64_t> mask_dim) {
#ifdef WITH_CUDA
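  // Fast path: a contiguous 3-dim NestedTensor with a regular last dimension
  // and a requested 2-dim mask can be padded directly on the GPU by a custom
  // kernel, avoiding the generic pad_nt path below.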
  if (get_dim(nt) == 3 && get_is_contiguous(nt) && mask_dim && *mask_dim == 2) {
    auto nt_opt_size = get_opt_sizes(nt);
    Tensor nt_buffer = get_buffer(nt);
    if (nt_opt_size[2] && nt_buffer.is_cuda()) {
      Tensor nt_sizes_ =
          get_efficient_nested_size(nt).sizes().to(torch::kInt32);
      TORCH_CHECK(nt_sizes_.dim() == 2, "NestedTensor metadata of unexpected dimension.");
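      // Column 0 of the per-constituent sizes holds the lengths along the
      // irregular dimension; the largest one determines the padded size.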
      Tensor nt_sizes = at::native::narrow(nt_sizes_, 1, 0, 1);
      int max_size_1 = nt_sizes.max().item<int>();
      nt_sizes =
          at::cumsum(nt_sizes, 0).to(torch::kInt32).reshape({-1});
      nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes});
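      // Running totals of the lengths (with a leading zero) tell the kernel
      // where each constituent starts inside the flat buffer.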
      Tensor output = torch::zeros(
          {*nt_opt_size[0], max_size_1, *nt_opt_size[2]}, nt_buffer.options());
      nt_sizes = nt_sizes.to(torch::kCUDA);
      Tensor output_mask = torch::zeros(
          {*nt_opt_size[0], max_size_1}, nt_buffer.options());
      output_mask = output_mask.to(torch::kInt32);
      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
      nested_tensor::cuda::add_padding_mask_kernelLauncher(
          nt_buffer.data_ptr<float>(),
          output.data_ptr<float>(),
          output_mask.data_ptr<int>(),
          nt_sizes.data_ptr<int>(),
          *nt_opt_size[0],
          output_mask.stride(0),
          output.stride(0),
          *nt_opt_size[2],
          defaultStream);
      return std::make_tuple(output, output_mask.to(torch::kBool));
    }
  }
#endif
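  // Generic path: first validate the requested mask dimension.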
  TORCH_CHECK(
      !mask_dim || *mask_dim <= get_dim(nt),
      "Requested mask dimension ",
      *mask_dim,
      " is bigger than dimension ",
      get_dim(nt),
      " of given NestedTensor.");

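  // Degenerate case: a 1-dim NestedTensor holding a single constituent needs
  // no padding; its flattened buffer is returned with a trivially true mask.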
  auto opt_sizes = get_opt_sizes(nt);
  if (opt_sizes.size() == 1 && *opt_sizes[0] == 1) {
    nt = NestedTensor_contiguous(nt);
    Tensor nt_buffer = get_buffer(nt);
    nt_buffer = nt_buffer.reshape({-1});
    Tensor result_mask = !mask_dim || *mask_dim == 0 ? torch::tensor(true)
                                                     : torch::tensor({true});
    return std::make_tuple(nt_buffer, result_mask);
  }

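  // General case: pad every constituent up to the per-dimension maximum size,
  // then merge the resulting mask down to the requested mask_dim.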
  auto max_size = get_max_size(nt);
  at::Tensor res_tensor;
  at::Tensor res_mask;
  std::tie(res_tensor, res_mask) = pad_nt(nt, max_size);
  return merge_tensor_mask(res_tensor, res_mask, mask_dim);
}
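
For illustration, the sketch below mirrors the padded-tensor-plus-mask layout this function returns ([batch, max_len, feature] zero-padded data and a [batch, max_len] boolean mask), written against plain ATen tensors instead of the NestedTensor type. The helper pad_with_mask, the example sequences, and the feature_dim parameter are names made up for this sketch, not part of the library.

#include <torch/torch.h>
#include <algorithm>
#include <iostream>
#include <vector>

// Hypothetical helper: zero-pad variable-length [len_i, feature] sequences to
// [batch, max_len, feature] and build a boolean mask that is true for real
// entries and false for padding.
std::tuple<torch::Tensor, torch::Tensor> pad_with_mask(
    const std::vector<torch::Tensor>& seqs, int64_t feature_dim) {
  int64_t max_len = 0;
  for (const auto& s : seqs) {
    max_len = std::max(max_len, s.size(0));
  }
  auto output = torch::zeros({(int64_t)seqs.size(), max_len, feature_dim});
  auto mask = torch::zeros({(int64_t)seqs.size(), max_len}, torch::kBool);
  for (int64_t i = 0; i < (int64_t)seqs.size(); ++i) {
    int64_t len = seqs[i].size(0);
    output[i].narrow(0, 0, len).copy_(seqs[i]);  // copy the real rows
    mask[i].narrow(0, 0, len).fill_(true);       // mark them as valid
  }
  return std::make_tuple(output, mask);
}

int main() {
  // Two made-up sequences of lengths 3 and 1 with feature size 4.
  std::vector<torch::Tensor> seqs = {torch::rand({3, 4}), torch::rand({1, 4})};
  torch::Tensor padded, mask;
  std::tie(padded, mask) = pad_with_mask(seqs, 4);
  std::cout << padded.sizes() << " " << mask.sizes() << std::endl;  // [2, 3, 4] [2, 3]
}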