in torch_xla/csrc/tensor_util.cpp [923:952]
// Translates a PyTorch scalar type into the XLA primitive type used to
// represent it. Every scalar type without an entry below is reported through
// XLA_ERROR() with a "Type not supported" message.
xla::PrimitiveType TensorTypeToRawXlaType(at::ScalarType scalar_type) {
  switch (scalar_type) {
    // Floating-point types.
    case at::ScalarType::Double:        return xla::PrimitiveType::F64;
    case at::ScalarType::Float:         return xla::PrimitiveType::F32;
    case at::ScalarType::BFloat16:      return xla::PrimitiveType::BF16;
    case at::ScalarType::Half:          return xla::PrimitiveType::F16;
    // Boolean type.
    case at::ScalarType::Bool:          return xla::PrimitiveType::PRED;
    // Integral types (Byte is the only unsigned one).
    case at::ScalarType::Byte:          return xla::PrimitiveType::U8;
    case at::ScalarType::Char:          return xla::PrimitiveType::S8;
    case at::ScalarType::Short:         return xla::PrimitiveType::S16;
    case at::ScalarType::Int:           return xla::PrimitiveType::S32;
    case at::ScalarType::Long:          return xla::PrimitiveType::S64;
    // Complex types.
    case at::ScalarType::ComplexFloat:  return xla::PrimitiveType::C64;
    case at::ScalarType::ComplexDouble: return xla::PrimitiveType::C128;
    default:
      // Unmapped types fall through to the single error path below.
      break;
  }
  XLA_ERROR() << "Type not supported: " << scalar_type;
}