in torch_xla/csrc/tensor_util.cpp [483:550]
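// Copies the contents of an at::Tensor (whose elements are read as SType) into
// a raw destination buffer, converting each element to the C++ type that
// matches the XLA primitive type of dest_shape.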
template <typename SType>
void TensorToBufferSType(const at::Tensor& tensor, const xla::Shape& dest_shape,
void* dest_buffer, size_t dest_buffer_size,
const Device& device) {
switch (dest_shape.element_type()) {
case xla::PrimitiveType::BF16:
TensorToBuffer<SType, tensorflow::bfloat16>(
tensor, dest_shape, dest_buffer, dest_buffer_size, device);
break;
case xla::PrimitiveType::F16:
TensorToBuffer<SType, xla::half>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::F32:
TensorToBuffer<SType, float>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::F64:
TensorToBuffer<SType, double>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::PRED:
TensorToBuffer<SType, bool>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::U8:
TensorToBuffer<SType, xla::uint8>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::S8:
TensorToBuffer<SType, xla::int8>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::S16:
TensorToBuffer<SType, xla::int16>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::U16:
TensorToBuffer<SType, xla::uint16>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::S32:
TensorToBuffer<SType, xla::int32>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::U32:
TensorToBuffer<SType, xla::uint32>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::S64:
TensorToBuffer<SType, xla::int64_t>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::U64:
TensorToBuffer<SType, xla::uint64>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::C64:
TensorToBuffer<SType, xla::complex64>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
case xla::PrimitiveType::C128:
TensorToBuffer<SType, xla::complex128>(tensor, dest_shape, dest_buffer,
dest_buffer_size, device);
break;
default:
XLA_ERROR() << "Destination shape type not supported: " << dest_shape;
}
}
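// A minimal sketch of how a caller might pair with TensorToBufferSType: the
// source-type dispatch switches on the at::Tensor's scalar type to pick SType,
// after which TensorToBufferSType switches on the destination XLA element type.
// The name PopulateBuffer and the subset of scalar types handled here are
// illustrative only, not the actual torch_xla entry point.
void PopulateBuffer(const at::Tensor& tensor, const xla::Shape& dest_shape,
                    void* dest_buffer, size_t dest_buffer_size,
                    const Device& device) {
  switch (tensor.scalar_type()) {
    case at::ScalarType::Float:
      TensorToBufferSType<float>(tensor, dest_shape, dest_buffer,
                                 dest_buffer_size, device);
      break;
    case at::ScalarType::Long:
      TensorToBufferSType<int64_t>(tensor, dest_shape, dest_buffer,
                                   dest_buffer_size, device);
      break;
    // ... remaining at::ScalarType values would follow the same pattern.
    default:
      XLA_ERROR() << "Tensor type not supported: " << tensor.scalar_type();
  }
}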