Tensor to_padded_tensor()

in nestedtensor/csrc/masking.cpp [497:581]

Converts a NestedTensor into a dense padded tensor. A CUDA fast path launches a padding kernel directly on the contiguous buffer for 2- to 4-dimensional inputs; otherwise each constituent tensor is padded to the nested max size on the host and the results are stacked.

Tensor to_padded_tensor(Tensor nt, double padding) {
#ifdef WITH_CUDA
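  // Fast path: nested tensors of dimension 2 to 4 whose contiguous buffer
  // lives on a CUDA device are padded by a dedicated CUDA kernel.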
  if ((get_dim(nt) >= 2 && get_dim(nt) <= 4)) {
    nt = NestedTensor_contiguous(nt, c10::MemoryFormat::Contiguous);
    auto nt_opt_size = get_opt_sizes(nt);
    Tensor nt_buffer = get_buffer(nt);
    if (nt_buffer.is_cuda()) {
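      // Derive per-tensor sizes, buffer offsets and the final padded shape
      // from the efficient (flattened) nested size metadata.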
      auto esize = get_efficient_nested_size(nt);
      at::Tensor nt_sizes = esize.sizes();
      Tensor offsets = batch_offsets_from_efficient_size(esize);
      std::vector<int64_t> new_size = padded_size_from_efficient_size(esize);
      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
      Tensor output = at::empty(IntArrayRef(new_size), nt_buffer.options());
      Tensor new_size_tensor = torch::tensor(new_size);

      int64_t input_dim = nt_sizes.size(1);
      int64_t batch_size = nt_sizes.size(0);
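      // Pack all size metadata into a single tensor so it reaches the GPU in
      // one non-blocking copy, then split it back into its pieces on-device.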
      at::Tensor metadata = at::cat({new_size_tensor, offsets, nt_sizes.reshape(-1)});
      metadata = metadata.to(at::Device(kCUDA), torch::kInt32, true, true);

      std::vector<int64_t> split_sizes;
      split_sizes.push_back(new_size_tensor.numel());
      split_sizes.push_back(offsets.numel());
      split_sizes.push_back(nt_sizes.numel());

      std::vector<Tensor> split = at::split_with_sizes(metadata, IntArrayRef(split_sizes), 0);

      new_size_tensor = split[0];
      offsets = split[1];
      nt_sizes = split[2];

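      // Dispatch on the buffer dtype; the padding kernel supports only
      // half and float inputs.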
      if (nt_buffer.dtype() == torch::kFloat16) {
        nested_tensor::cuda::add_padding_kernelLauncher(
            nt_buffer.data_ptr<c10::Half>(),
            output.data_ptr<c10::Half>(),
            (c10::Half)(padding),
            offsets.data_ptr<int>(),
            nt_sizes.data_ptr<int>(),
            input_dim,
            new_size_tensor.data_ptr<int>(),
            batch_size,
            defaultStream);
        return output;
      }
      if (nt_buffer.dtype() == torch::kFloat) {
        nested_tensor::cuda::add_padding_kernelLauncher(
            nt_buffer.data_ptr<float>(),
            output.data_ptr<float>(),
            (float)(padding),
            offsets.data_ptr<int>(),
            nt_sizes.data_ptr<int>(),
            input_dim,
            new_size_tensor.data_ptr<int>(),
            batch_size,
            defaultStream);
        return output;
      }
      TORCH_CHECK(false, "Input datatype ", nt_buffer.dtype(), " is not supported.");
    }
  }
#endif
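  // Fallback path (non-CUDA buffers or dimensions outside 2 to 4): pad each
  // constituent tensor to the nested max size and stack the results.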
  auto opt_sizes = get_opt_sizes(nt);
  if (opt_sizes.size() == 1 && *opt_sizes[0] == 1) {
    nt = NestedTensor_contiguous(nt);
    return get_buffer(nt);
  }
  auto max_size = get_max_size(nt);
  TensorNode structure = get_nested_tensor_structure(nt);
  if (structure.degree() == 0) {
    return torch::tensor({padding});
  }
  std::vector<Tensor> res_tensor;
  for (auto child : structure.unbind()) {
    at::Tensor tensor = child.payload();
    if (get_numel(tensor) == 0) {
      TORCH_CHECK(false, "Empty tensors are not yet supported.");
    }
    // Don't pad in case of a scalar
    if (get_dim(tensor) == 0) {
      res_tensor.push_back(tensor);
      continue;
    }
    res_tensor.push_back(pad_tensor_to_shape(tensor, max_size, padding));
  }
  return at::stack(res_tensor);
}
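
The fallback path relies on pad_tensor_to_shape, which is defined elsewhere in the repo. As a rough, hypothetical illustration (not the repo's actual helper), the same effect can be obtained with at::constant_pad_nd, appending padding at the back of every dimension:

Tensor pad_tensor_to_shape_sketch(
    const Tensor& t,
    IntArrayRef goal_shape,
    double value) {
  TORCH_CHECK(
      (size_t)t.dim() == goal_shape.size(),
      "Tensor and goal shape must have the same number of dimensions.");
  // constant_pad_nd takes (front, back) pairs starting from the last
  // dimension; only the back of each dimension is padded here.
  std::vector<int64_t> pad;
  for (int64_t i = t.dim() - 1; i >= 0; i--) {
    pad.push_back(0);
    pad.push_back(goal_shape[i] - t.size(i));
  }
  return at::constant_pad_nd(t, IntArrayRef(pad), value);
}

The CUDA path similarly depends on batch_offsets_from_efficient_size, also defined elsewhere. Conceptually it amounts to a prefix sum over the per-tensor element counts; a hypothetical sketch, assuming nt_sizes is an N x D int64 CPU tensor of per-tensor sizes:

Tensor batch_offsets_sketch(const Tensor& nt_sizes) {
  int64_t batch_size = nt_sizes.size(0);
  int64_t dim = nt_sizes.size(1);
  auto sizes = nt_sizes.accessor<int64_t, 2>();
  // offsets[i] is the element offset of tensor i inside the contiguous
  // buffer; offsets[batch_size] is the total element count.
  Tensor offsets = at::zeros({batch_size + 1}, at::kLong);
  auto out = offsets.accessor<int64_t, 1>();
  for (int64_t i = 0; i < batch_size; i++) {
    int64_t numel = 1;
    for (int64_t d = 0; d < dim; d++) {
      numel *= sizes[i][d];
    }
    out[i + 1] = out[i] + numel;
  }
  return offsets;
}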