in torch_xla/csrc/tensor_util.cpp [994:1024]
// Translates an ATen scalar type into the XLA primitive type used on `device`.
//
// The switch only selects the XLA element type that corresponds one-to-one to
// `scalar_type`. That type is then passed through GetDevicePrimitiveType(),
// which may substitute a device-specific type. What it substitutes is
// defined elsewhere — NOTE(review): confirm the exact rules there.
//
// Scalar types without a mapping here are reported through XLA_ERROR().
xla::PrimitiveType MakeXlaPrimitiveType(at::ScalarType scalar_type,
                                        const Device* device) {
  // Starts invalid; every supported case below overwrites it.
  xla::PrimitiveType element_type = xla::PrimitiveType::PRIMITIVE_TYPE_INVALID;
  switch (scalar_type) {
    // Floating point.
    case at::ScalarType::Double:
      element_type = xla::PrimitiveType::F64;
      break;
    case at::ScalarType::Float:
      element_type = xla::PrimitiveType::F32;
      break;
    case at::ScalarType::BFloat16:
      element_type = xla::PrimitiveType::BF16;
      break;
    case at::ScalarType::Half:
      element_type = xla::PrimitiveType::F16;
      break;
    // Boolean and integral.
    case at::ScalarType::Bool:
      element_type = xla::PrimitiveType::PRED;
      break;
    case at::ScalarType::Byte:
      element_type = xla::PrimitiveType::U8;
      break;
    case at::ScalarType::Char:
      element_type = xla::PrimitiveType::S8;
      break;
    case at::ScalarType::Short:
      element_type = xla::PrimitiveType::S16;
      break;
    case at::ScalarType::Int:
      element_type = xla::PrimitiveType::S32;
      break;
    case at::ScalarType::Long:
      element_type = xla::PrimitiveType::S64;
      break;
    // Complex.
    case at::ScalarType::ComplexFloat:
      element_type = xla::PrimitiveType::C64;
      break;
    case at::ScalarType::ComplexDouble:
      element_type = xla::PrimitiveType::C128;
      break;
    default:
      XLA_ERROR() << "Type not supported: " << scalar_type;
  }
  return GetDevicePrimitiveType(element_type, device);
}