in nestedtensor/csrc/masking.cpp [433:495]
// Inverse of to_padded_tensor: strips padding from `padded` and packs the
// real values into a single contiguous 1-D buffer, which is then wrapped as
// a nested tensor described by `target_size`.
//
// padded      - dense tensor [batch, d1, ..., dk]; everything outside each
//               entry's true extent (per `target_size`) is padding.
// target_size - per-entry sizes of the nested tensor to reconstruct; must
//               have the same dimensionality as `padded` (checked below).
//
// Returns: nested tensor wrapping the packed buffer.
Tensor from_padded_tensor(Tensor padded, EfficientSizeNode target_size) {
  TORCH_CHECK(padded.dim() == target_size.dim(),
      "Target size has different dimension as input padded Tensor.");
#ifdef WITH_CUDA
  // Fast path: dedicated CUDA kernel that drops padding on-device.
  // Only taken for contiguous fp16 CUDA inputs of rank 2-4 — the shapes
  // the remove_padding kernel launcher handles.
  if (padded.dim() > 1 && padded.dim() < 5 &&
      get_is_contiguous(padded) && padded.is_cuda() &&
      padded.dtype() == torch::kFloat16) {
    Tensor target_offsets = batch_offsets_from_efficient_size(target_size);
    std::vector<int64_t> padded_sizes = padded.sizes().vec();
    Tensor padded_sizes_tensor = torch::tensor(padded_sizes);
    // Output buffer holds exactly the non-padding elements.
    Tensor output = torch::empty({target_size.numel()}, padded.options());
    Tensor target_size_sizes = target_size.sizes();
    // Concatenate all three int arrays so one async host-to-device copy
    // moves the whole metadata set; split it back apart afterwards. The
    // order here must match the split_sizes order below.
    at::Tensor metadata = at::cat({target_size_sizes.reshape(-1), padded_sizes_tensor, target_offsets});
    metadata = metadata.to(at::Device(kCUDA), torch::kInt32, /*non_blocking=*/true, /*copy=*/true);
    std::vector<int64_t> split_sizes;
    split_sizes.push_back(target_size_sizes.numel());
    split_sizes.push_back(padded_sizes_tensor.numel());
    split_sizes.push_back(target_offsets.numel());
    std::vector<Tensor> split = at::split_with_sizes(metadata, IntArrayRef(split_sizes), 0);
    target_size_sizes = split[0];
    padded_sizes_tensor = split[1];
    target_offsets = split[2];
    at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
    nested_tensor::cuda::remove_padding_kernelLauncher(
        padded.data_ptr<c10::Half>(),
        output.data_ptr<c10::Half>(),
        target_offsets.data_ptr<int>(),
        padded_sizes_tensor.data_ptr<int>(),
        target_size_sizes.data_ptr<int>(),
        padded.dim() - 1,
        padded.size(0),
        defaultStream);
    return wrap_buffer(std::move(output), target_size);
  }
#endif
  // Fallback path: build a mask that is true exactly on the non-padding
  // region of every entry, then masked_select the payload out of `padded`.
  // The mask's trailing shape is the per-dimension max over all entries,
  // which matches the padded tensor's trailing dimensions.
  at::Tensor target_size_tensor = std::get<0>(at::max(target_size.sizes(), 0));
  std::vector<int64_t> target_size_vec(target_size_tensor.data_ptr<int64_t>(),
      target_size_tensor.data_ptr<int64_t>() + target_size_tensor.numel());
  std::vector<at::Tensor> masks;
  std::vector<at::Tensor> all_sizes = target_size.sizes().unbind();
  masks.reserve(all_sizes.size());
  // Range-for replaces the original `int64_t i < all_sizes.size()` loop,
  // which compared signed against unsigned.
  for (const at::Tensor& entry_sizes_tensor : all_sizes) {
    std::vector<int64_t> sizes_i(
        entry_sizes_tensor.data_ptr<int64_t>(),
        entry_sizes_tensor.data_ptr<int64_t>() + entry_sizes_tensor.numel());
    // NOTE(review): kByte masks with masked_select are deprecated in newer
    // PyTorch in favor of kBool — confirm against the pinned torch version
    // before changing, as it affects the deprecation-warning behavior only.
    at::Tensor mask_i = padded.new_full(
        IntArrayRef(sizes_i),
        true,
        torch::kByte,
        c10::nullopt,
        c10::nullopt,
        c10::nullopt);
    // Zero-extend each entry's all-true mask out to the common max shape.
    mask_i = pad_tensor_to_shape(mask_i, target_size_vec);
    masks.push_back(std::move(mask_i));
  }
  at::Tensor final_mask = at::stack(masks);
  at::Tensor new_buffer = padded.masked_select(final_mask);
  return wrap_buffer(std::move(new_buffer), target_size);
}