// Tensor from_padded_tensor()
// Extracted from nestedtensor/csrc/masking.cpp, lines 433-495.

// Converts a dense, padded Tensor back into a nested-tensor buffer whose
// per-entry shapes are described by `target_size`.
//
// Two paths:
//   1. CUDA fast path (fp16, contiguous, 2 <= dim <= 4): launches a custom
//      remove-padding kernel that copies only the valid elements into a flat
//      output buffer.
//   2. Generic fallback: builds a boolean mask per nested entry, pads each
//      mask up to the common padded shape, stacks them, and uses
//      masked_select to gather the valid elements.
//
// Returns a nested tensor wrapping the flat buffer with `target_size`.
// Throws (TORCH_CHECK) if `padded` and `target_size` disagree on rank.
Tensor from_padded_tensor(Tensor padded, EfficientSizeNode target_size) {
  TORCH_CHECK(padded.dim() == target_size.dim(),
      "Target size has different dimension as input padded Tensor.");
#ifdef WITH_CUDA
  // Fast path: fused CUDA kernel. Restricted to the cases the kernel
  // launcher supports: rank 2..4, contiguous, CUDA, half precision.
  if (padded.dim() > 1 && padded.dim() < 5 &&
      get_is_contiguous(padded) && padded.is_cuda() &&
      padded.dtype() == torch::kFloat16) {
    Tensor target_offsets = batch_offsets_from_efficient_size(target_size);
    std::vector<int64_t> padded_sizes = padded.sizes().vec();
    Tensor padded_sizes_tensor = torch::tensor(padded_sizes);
    Tensor output = torch::empty({target_size.numel()}, padded.options());
    Tensor target_size_sizes = target_size.sizes();

    // Pack all three metadata tensors into a single int32 tensor so only
    // one host-to-device transfer is needed (the `true, true` args are
    // presumably non_blocking/copy — matches Tensor::to(Device, dtype,
    // non_blocking, copy)).
    at::Tensor metadata = at::cat({target_size_sizes.reshape(-1), padded_sizes_tensor, target_offsets});
    metadata = metadata.to(at::Device(kCUDA), torch::kInt32, true, true);

    // Recover the three pieces by splitting the device-side copy.
    std::vector<int64_t> split_sizes;
    split_sizes.push_back(target_size_sizes.numel());
    split_sizes.push_back(padded_sizes_tensor.numel());
    split_sizes.push_back(target_offsets.numel());

    std::vector<Tensor> split = at::split_with_sizes(metadata, IntArrayRef(split_sizes), 0);

    target_size_sizes = split[0];
    padded_sizes_tensor = split[1];
    target_offsets = split[2];

    at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
    nested_tensor::cuda::remove_padding_kernelLauncher(
        padded.data_ptr<c10::Half>(),
        output.data_ptr<c10::Half>(),
        target_offsets.data_ptr<int>(),
        padded_sizes_tensor.data_ptr<int>(),
        target_size_sizes.data_ptr<int>(),
        padded.dim() - 1,
        padded.size(0),
        defaultStream);
    return wrap_buffer(std::move(output), target_size);
  }
#endif
  // Generic fallback: compute the common (maximal) per-dimension size of
  // the nested entries; each entry's mask is padded up to this shape.
  at::Tensor target_size_tensor = std::get<0>(at::max(target_size.sizes(), 0));
  std::vector<int64_t> target_size_vec(target_size_tensor.data_ptr<int64_t>(),
      target_size_tensor.data_ptr<int64_t>() + target_size_tensor.numel());
  std::vector<at::Tensor> masks;
  std::vector<at::Tensor> all_sizes = target_size.sizes().unbind();
  masks.reserve(all_sizes.size());  // final size known up front
  // NOTE: index is size_t to avoid a signed/unsigned comparison with
  // std::vector::size().
  for (size_t i = 0; i < all_sizes.size(); ++i) {
    // Sizes of the i-th nested entry, read out of the int64 size tensor.
    std::vector<int64_t> sizes_i(
        all_sizes[i].data_ptr<int64_t>(),
        all_sizes[i].data_ptr<int64_t>() + all_sizes[i].numel());
    // All-true mask of the entry's shape, then zero-padded to the common
    // padded shape so it selects exactly the valid region of `padded`.
    at::Tensor mask_i = padded.new_full(
                                    IntArrayRef(sizes_i),
                                    true,
                                    torch::kByte,
                                    c10::nullopt,
                                    c10::nullopt,
                                    c10::nullopt);
    mask_i = pad_tensor_to_shape(mask_i, target_size_vec);
    masks.push_back(mask_i);
  }
  at::Tensor final_mask = at::stack(masks);
  // masked_select flattens the selected elements into a 1-D buffer, which
  // is exactly the layout wrap_buffer expects.
  at::Tensor new_buffer = padded.masked_select(final_mask);
  return wrap_buffer(std::move(new_buffer), target_size);
}