in nestedtensor/csrc/BinaryOps.cpp [241:300]
// Elementwise multiplication `self_ * other_` with NestedTensor support.
//
// When built WITH_CUDA, and `self_` is a NestedTensor while `other_` is a
// regular Tensor, a fused kernel path is taken if all of the following hold:
//   * `self` (made contiguous) has dim 4, and opt sizes 0 and 1 are regular;
//   * `other` has shape (1, S1, 1, 1), where S1 == opt size 1 of `self`;
//   * both operands are Half;
//   * `other` lives on CUDA, on the same device as `self`'s buffer.
// That path launches nested_tensor::cuda::mul_scalar_kernelLauncher into a
// clone of `self`'s buffer and rewraps it with `self`'s nested size/stride.
// NOTE(review): the layout is presumably (N, C, H, W) with `other` a
// per-channel scale; this is inferred from the shape checks, so confirm.
//
// Every other case expands `other_` against `self_` via _expand_other_as and
// applies at::mul per constituent via map_nested_tensor.
Tensor NestedTensor_mul_Tensor(const Tensor& self_, const Tensor& other_) {
  Tensor self = self_;
  Tensor other = other_;
#ifdef WITH_CUDA
  // Only the CUDA fast path needs the contiguous view and its metadata. In
  // non-CUDA builds this work would be computed and then discarded.
  if (is_nested_tensor_impl(self) && !is_nested_tensor_impl(other)) {
    self = NestedTensor_contiguous(self);
    const int64_t self_dim = get_dim(self);
    auto self_opt_sizes = get_opt_sizes(self);
    // The kernel dereferences raw device pointers. Both operands must already
    // be on the same CUDA device, otherwise we use the generic fallback below.
    // The buffer check is evaluated last so get_buffer runs only when
    // everything else already matches.
    if (self_dim == 4 && other.dim() == 4 &&
        self_opt_sizes[0] &&
        self_opt_sizes[1] &&
        (*self_opt_sizes[1]) == other.size(1) &&
        other.size(0) == 1 &&
        other.size(2) == 1 &&
        other.size(3) == 1 &&
        self.dtype() == c10::ScalarType::Half &&
        other.dtype() == c10::ScalarType::Half &&
        other.is_cuda() &&
        other.device() == get_buffer(self).device()) {
      other = other.contiguous();
      at::Tensor self_buffer = get_buffer(self);
      Tensor nt_sizes_ =
          get_efficient_nested_size(self).sizes().to(torch::kInt32);
      // Validate the metadata rank before narrowing along dim 1.
      TORCH_CHECK(
          nt_sizes_.dim() == 2,
          "NestedTensor metadata of unexpected dimension.");
      // Per constituent: product of its size columns 1 and 2.
      Tensor nt_sizes_1 = at::native::narrow(nt_sizes_, 1, 1, 1);
      Tensor nt_sizes_2 = at::native::narrow(nt_sizes_, 1, 2, 1);
      Tensor nt_sizes_all = nt_sizes_1 * nt_sizes_2;

      // Repeat each constituent's product once per index of opt-size-1.
      // The prefix sum below then yields the start offset of every
      // (constituent, dim-1 index) slab, presumably in the contiguous buffer.
      const int64_t num_inner = *self_opt_sizes[1];
      const int64_t num_constituents = nt_sizes_all.size(0);
      std::vector<int> numbers;
      numbers.reserve(static_cast<size_t>(num_constituents * num_inner));
      for (int64_t i = 0; i < num_constituents; i++) {
        // Read the scalar once per constituent, not once per inner index.
        const int slab_numel = nt_sizes_all[i].item<int>();
        numbers.insert(numbers.end(), static_cast<size_t>(num_inner), slab_numel);
      }
      at::Tensor numbers_t = torch::tensor(numbers).to(torch::kInt32);
      Tensor nt_sizes_cumsum =
          at::cumsum(numbers_t, 0).to(torch::kInt32).reshape({-1});
      Tensor nt_sizes =
          at::cat({torch::tensor({0}, torch::kInt32), nt_sizes_cumsum});
      // The offsets must live on the same device as the data the kernel reads.
      nt_sizes = nt_sizes.to(self_buffer.device(), torch::kInt32);

      // Use the current stream rather than the default stream. This orders the
      // kernel with work the caller has queued, as other PyTorch ops do.
      // NOTE(review): no device guard is set, so this assumes the current
      // device is the buffer's device.
      at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
      at::Tensor result_buffer = self_buffer.clone();
      c10::Half* self_ptr = self_buffer.data_ptr<c10::Half>();
      c10::Half* other_ptr = other.data_ptr<c10::Half>();
      c10::Half* result_ptr = result_buffer.data_ptr<c10::Half>();
      nested_tensor::cuda::mul_scalar_kernelLauncher(
          self_ptr,
          other_ptr,
          result_ptr,
          static_cast<int>(*self_opt_sizes[0] * *self_opt_sizes[1]),
          static_cast<int>(*self_opt_sizes[0]),
          nt_sizes.data_ptr<int>(),
          stream);
      return wrap_buffer(
          std::move(result_buffer),
          get_efficient_nested_size(self),
          get_efficient_nested_stride(self));
    }
  }
#endif
  // Generic path: broadcast `other_` against `self_`, then multiply each
  // constituent pair with at::mul.
  std::tie(self, other) = _expand_other_as(self_, other_);
  return map_nested_tensor(
      [](Tensor s, Tensor o) { return at::mul(s, o); }, self, other);
}