Tensor NestedTensor_sub_Tensor()

in nestedtensor/csrc/BinaryOps.cpp [357:421]

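Subtraction of a plain tensor from a NestedTensor. When self is a nested tensor and other is a regular tensor, a fused CUDA fast path handles the common case of subtracting a per-channel (1, C, 1, 1) half-precision tensor from a 4-dim half-precision nested tensor in a single kernel over the contiguous buffer; all other cases fall back to broadcasting other and mapping at::sub over the nested components.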

Tensor NestedTensor_sub_Tensor(
    const Tensor& self_,
    const Tensor& other_,
    const Scalar& alpha) {
  Tensor self = self_;
  Tensor other = other_;
  if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
    self = NestedTensor_contiguous(self);
    int64_t self_dim = get_dim(self);
    auto self_opt_sizes = get_opt_sizes(self);
#ifdef WITH_CUDA
    // Fused CUDA fast path: subtracting a dense (1, C, 1, 1) half-precision
    // tensor (one scalar per channel) from a 4-dim half-precision
    // NestedTensor whose batch and channel dimensions are regular. The
    // fused kernel is never passed alpha, so this path also requires the
    // default alpha == 1.
    if (self_dim == 4 && other.dim() == 4 &&
        self_opt_sizes[0] &&
        self_opt_sizes[1] &&
        (*self_opt_sizes[1]) == other.size(1) &&
        other.size(0) == 1 &&
        other.size(2) == 1 &&
        other.size(3) == 1 &&
        self.dtype() == c10::ScalarType::Half &&
        other.dtype() == c10::ScalarType::Half &&
        !alpha.isComplex() && alpha.toDouble() == 1.0) {
      other = other.contiguous();
      at::Tensor self_buffer = get_buffer(self);
      Tensor nt_sizes_ =
          get_efficient_nested_size(self).sizes().to(torch::kInt32);
      TORCH_CHECK(
          nt_sizes_.dim() == 2,
          "NestedTensor metadata of unexpected dimension.");
      // Per-entry element count of the ragged spatial plane: column 1 of
      // the size matrix holds heights, column 2 holds widths.
      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;
      // Repeat each entry's plane size once per channel: every
      // (entry, channel) pair is one contiguous segment of the buffer.
      std::vector<int> numbers;
      for (int64_t i = 0; i < nt_sizes_all.size(0); i++) {
        for (int64_t j = 0; j < *self_opt_sizes[1]; j++) {
          numbers.push_back(nt_sizes_all[i].item<int>());
        }
      }
      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
      Tensor nt_sizes_cumsum =
          at::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
      // Prepend 0 to turn the cumulative counts into per-segment offsets.
      // Example: entries with spatial sizes 2x3 and 4x5 and C = 2 channels
      // give numbers = [6, 6, 20, 20] and offsets = [0, 6, 12, 32, 52].
      Tensor nt_sizes =
          at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
      nt_sizes = nt_sizes.to(torch::kCUDA);
      at::cuda::CUDAStream defaultStream = at::cuda::getDefaultCUDAStream();
      // Write the result into a fresh copy of the contiguous buffer.
      at::Tensor result_buffer = self_buffer.clone();

      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
      c10::Half* other_ptr = other.data_ptr<c10::Half>();
      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
      nested_tensor::cuda::sub_scalar_kernelLauncher(
          self_ptr,
          other_ptr,
          result_ptr,
          (int)(*self_opt_sizes[0] * *self_opt_sizes[1]), // number of segments
          (int)(*self_opt_sizes[0]),                      // batch size
          nt_sizes.data_ptr<int>(),
          defaultStream);
      // The result reuses the input's nested size and stride metadata.
      return wrap_buffer(
          std::move(result_buffer),
          get_efficient_nested_size(self),
          get_efficient_nested_stride(self));
    }
#endif
  }
  // Fallback: broadcast other against self and apply at::sub component-wise.
  std::tie(self, other) = _expand_other_as(self_, other_);
  return map_nested_tensor(
      [&alpha](Tensor s, Tensor o) { return at::sub(s, o, alpha); },
      self,
      other);
}
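
The launcher called above is defined in the repository's CUDA sources and is not part of this listing. As a rough sketch of the contract the host code implies, assuming hypothetical kernel internals (the kernel name, launch shape, and channel lookup below are illustrative, not the repository's actual implementation): each of the batch_size * C segments described by the offsets array covers one (entry, channel) run of the contiguous buffer, and every element in a segment has the same per-channel scalar subtracted from it.

// Illustrative sketch only; the real implementation is
// nested_tensor::cuda::sub_scalar_kernelLauncher in the repo's CUDA sources.
#include <c10/cuda/CUDAStream.h>
#include <c10/util/Half.h>

__global__ void sub_scalar_kernel(
    const c10::Half* self,
    const c10::Half* other, // C per-channel scalars from the (1, C, 1, 1) tensor
    c10::Half* result,
    int num_segments, // batch_size * C
    int batch_size,
    const int* offsets) { // num_segments + 1 cumulative element counts
  int channels = num_segments / batch_size;
  // One block per (entry, channel) segment; threads stride over its elements.
  for (int seg = blockIdx.x; seg < num_segments; seg += gridDim.x) {
    c10::Half scalar = other[seg % channels];
    for (int i = offsets[seg] + threadIdx.x; i < offsets[seg + 1];
         i += blockDim.x) {
      result[i] = self[i] - scalar;
    }
  }
}

void sub_scalar_kernelLauncher(
    c10::Half* self,
    c10::Half* other,
    c10::Half* result,
    int num_segments,
    int batch_size,
    int* offsets,
    c10::cuda::CUDAStream stream) {
  sub_scalar_kernel<<<num_segments, 256, 0, stream.stream()>>>(
      self, other, result, num_segments, batch_size, offsets);
}

A production kernel would more likely vectorize loads or use __half2 arithmetic; the grid-stride loop is simply the most direct fit for the segment layout the host code builds.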