in torch_xla/csrc/tensor_util.cpp [552:607]
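// Copies an ATen tensor's data into the raw destination buffer, converting the
// values as needed to the element type and layout described by dest_shape. The
// switch below dispatches on the tensor's scalar type; unsupported dtypes
// trigger XLA_ERROR().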
void PopulateTensorBuffer(const at::Tensor& tensor,
                          const xla::Shape& dest_shape, void* dest_buffer,
                          size_t dest_buffer_size, const Device& device) {
  switch (tensor.type().scalarType()) {
    case at::ScalarType::Double:
      TensorToBufferSType<double>(tensor, dest_shape, dest_buffer,
                                  dest_buffer_size, device);
      break;
    case at::ScalarType::Float:
      TensorToBufferSType<float>(tensor, dest_shape, dest_buffer,
                                 dest_buffer_size, device);
      break;
    case at::ScalarType::BFloat16:
      TensorToBufferSType<at::BFloat16>(tensor, dest_shape, dest_buffer,
                                        dest_buffer_size, device);
      break;
    case at::ScalarType::Half:
      TensorToBufferSType<at::Half>(tensor, dest_shape, dest_buffer,
                                    dest_buffer_size, device);
      break;
    case at::ScalarType::Bool:
      TensorToBufferSType<bool>(tensor, dest_shape, dest_buffer,
                                dest_buffer_size, device);
      break;
    case at::ScalarType::Byte:
      TensorToBufferSType<uint8_t>(tensor, dest_shape, dest_buffer,
                                   dest_buffer_size, device);
      break;
    case at::ScalarType::Char:
      TensorToBufferSType<int8_t>(tensor, dest_shape, dest_buffer,
                                  dest_buffer_size, device);
      break;
    case at::ScalarType::Short:
      TensorToBufferSType<int16_t>(tensor, dest_shape, dest_buffer,
                                   dest_buffer_size, device);
      break;
    case at::ScalarType::Int:
      TensorToBufferSType<int32_t>(tensor, dest_shape, dest_buffer,
                                   dest_buffer_size, device);
      break;
    case at::ScalarType::Long:
      TensorToBufferSType<int64_t>(tensor, dest_shape, dest_buffer,
                                   dest_buffer_size, device);
      break;
    case at::ScalarType::ComplexFloat:
      TensorToBufferSType<c10::complex<float>>(tensor, dest_shape, dest_buffer,
                                               dest_buffer_size, device);
      break;
    case at::ScalarType::ComplexDouble:
      TensorToBufferSType<c10::complex<double>>(
          tensor, dest_shape, dest_buffer, dest_buffer_size, device);
      break;
    default:
      XLA_ERROR() << "Tensor type not supported: " << tensor.type();
  }
}
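
// Illustrative usage sketch (not part of the original source): size the
// destination buffer from the XLA shape and let PopulateTensorBuffer perform
// the per-dtype dispatch above. CreateComputationShapeFromTensor is an
// assumption about the surrounding torch_xla helpers, and <vector> is assumed
// to be included; this only makes the calling convention concrete.
void PopulateTensorBufferSketch(const at::Tensor& tensor,
                                const Device& device) {
  // Derive the destination XLA shape for this tensor (assumed helper).
  xla::Shape shape = CreateComputationShapeFromTensor(tensor, &device);
  // Allocate a buffer large enough for the XLA representation of the shape.
  std::vector<char> buffer(xla::ShapeUtil::ByteSizeOf(shape));
  // Copy/convert the tensor's data into the raw buffer.
  PopulateTensorBuffer(tensor, shape, buffer.data(), buffer.size(), device);
}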