in nestedtensor/csrc/masking.cpp [250:310]
// Converts a NestedTensor into a (dense tensor, mask) pair.
//
// Parameters:
//   nt       - the NestedTensor to convert.
//   mask_dim - optional requested dimensionality of the returned mask; must
//              not exceed get_dim(nt) (enforced via TORCH_CHECK on the
//              generic path).
//
// Returns one of:
//   * CUDA fast path (only when built WITH_CUDA): a zero-initialised
//     {N, max_len, D} tensor filled by add_padding_mask_kernelLauncher and
//     the matching {N, max_len} mask converted to bool.
//   * Single-size-entry case: the flattened buffer together with an all-true
//     mask (0-dim when mask_dim is unset or 0, otherwise 1-element).
//   * Otherwise: merge_tensor_mask applied to the result of pad_nt.
std::tuple<Tensor, Tensor> to_tensor_mask(
    Tensor nt,
    c10::optional<int64_t> mask_dim) {
#ifdef WITH_CUDA
  // Fast path: contiguous 3-dim NestedTensor with a 2-dim mask requested.
  if (get_dim(nt) == 3 && get_is_contiguous(nt) && mask_dim && *mask_dim == 2) {
    auto nt_opt_size = get_opt_sizes(nt);
    Tensor nt_buffer = get_buffer(nt);
    // The launcher below is fed float pointers, and data_ptr<float>() throws
    // on any other dtype. Non-float buffers therefore have to fall through to
    // the generic implementation instead of entering this branch.
    if (nt_opt_size[2] && nt_buffer.is_cuda() &&
        nt_buffer.scalar_type() == torch::kFloat) {
      Tensor nt_sizes_ =
          get_efficient_nested_size(nt).sizes().to(torch::kInt32);
      TORCH_CHECK(
          nt_sizes_.dim() == 2,
          "NestedTensor metadata of unexpected dimension.");
      // First column of the size matrix; its maximum becomes the padded
      // length of dimension 1 of the output.
      Tensor nt_sizes = at::native::narrow(nt_sizes_, 1, 0, 1);
      int max_size_1 = nt_sizes.max().item<int>();
      // Running sum of the first-column sizes with a leading 0, i.e. N + 1
      // entries handed to the kernel launcher.
      nt_sizes =
          at::cumsum(nt_sizes, 0).to(torch::kInt32).reshape({-1});
      nt_sizes = at::cat({torch::tensor({0}, torch::kInt32), nt_sizes});
      // Keep the offsets on the same device as the data they describe rather
      // than on whichever CUDA device happens to be current.
      nt_sizes = nt_sizes.to(nt_buffer.device());
      Tensor output = torch::zeros(
          {*nt_opt_size[0], max_size_1, *nt_opt_size[2]}, nt_buffer.options());
      // Allocate the mask directly as int32 (the type the launcher receives)
      // instead of creating a float tensor and converting it afterwards.
      Tensor output_mask = torch::zeros(
          {*nt_opt_size[0], max_size_1},
          nt_buffer.options().dtype(torch::kInt32));
      // Launch on the current stream so the kernel is ordered after the
      // allocations/copies above and before the caller's subsequent work.
      // NOTE(review): no device guard is installed here; this assumes
      // nt_buffer lives on the current CUDA device - confirm for multi-GPU.
      at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
      nested_tensor::cuda::add_padding_mask_kernelLauncher(
          nt_buffer.data_ptr<float>(),
          output.data_ptr<float>(),
          output_mask.data_ptr<int>(),
          nt_sizes.data_ptr<int>(),
          *nt_opt_size[0],
          output_mask.stride(0),
          output.stride(0),
          *nt_opt_size[2],
          stream);
      return std::make_tuple(output, output_mask.to(torch::kBool));
    }
  }
#endif
  // The message arguments are only evaluated when the check fails, in which
  // case mask_dim is guaranteed to hold a value.
  TORCH_CHECK(
      !mask_dim || *mask_dim <= get_dim(nt),
      "Requested mask dimension ",
      *mask_dim,
      " is bigger than dimension ",
      get_dim(nt),
      " of given NestedTensor.");
  auto opt_sizes = get_opt_sizes(nt);
  // Exactly one size entry and that entry equals 1: return the flattened
  // buffer with an all-true mask. The has_value check avoids dereferencing an
  // empty optional.
  if (opt_sizes.size() == 1 && opt_sizes[0] && *opt_sizes[0] == 1) {
    nt = NestedTensor_contiguous(nt);
    Tensor nt_buffer = get_buffer(nt);
    nt_buffer = nt_buffer.reshape({-1});
    // 0-dim mask when no mask dimension (or 0) was requested, otherwise a
    // 1-element mask.
    Tensor result_mask = !mask_dim || *mask_dim == 0 ? torch::tensor(true)
                                                     : torch::tensor({true});
    return std::make_tuple(nt_buffer, result_mask);
  }
  // Generic path: pad every constituent up to the maximum size, then let
  // merge_tensor_mask produce the final tensor/mask pair for mask_dim.
  auto max_size = get_max_size(nt);
  at::Tensor res_tensor;
  at::Tensor res_mask;
  std::tie(res_tensor, res_mask) = pad_nt(nt, max_size);
  return merge_tensor_mask(res_tensor, res_mask, mask_dim);
}