in Sources/x10/xla_tensor/tensor_util.cpp [727:755]
// Returns the at::ScalarType corresponding to the XLA primitive type
// `xla_type`.
//
// - BF16 yields at::ScalarType::Float when UseBF16() returns true, otherwise
//   at::ScalarType::BFloat16. F16 behaves the same way with UseF16() and
//   at::ScalarType::Half. NOTE(review): presumably these modes store Float
//   tensors in reduced precision on device, so reading them back should give
//   Float — confirm against UseBF16()/UseF16().
// - Unsigned integer types U16/U32/U64 share the result of their signed
//   counterparts S16/S32/S64 (Short/Int/Long). U8 and S8 keep distinct
//   types (Byte/Char).
// - Any other primitive type is reported through XLA_ERROR() with the
//   message "XLA type not supported: <type>".
at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
  switch (xla_type) {
    // Full-precision floating point maps one-to-one.
    case xla::PrimitiveType::F32:
      return at::ScalarType::Float;
    case xla::PrimitiveType::F64:
      return at::ScalarType::Double;

    // Reduced-precision floating point: the result depends on the
    // corresponding precision mode.
    case xla::PrimitiveType::BF16:
      if (UseBF16()) {
        return at::ScalarType::Float;
      }
      return at::ScalarType::BFloat16;
    case xla::PrimitiveType::F16:
      if (UseF16()) {
        return at::ScalarType::Float;
      }
      return at::ScalarType::Half;

    // Boolean predicate.
    case xla::PrimitiveType::PRED:
      return at::ScalarType::Bool;

    // 8-bit integers keep their signedness.
    case xla::PrimitiveType::U8:
      return at::ScalarType::Byte;
    case xla::PrimitiveType::S8:
      return at::ScalarType::Char;

    // Wider integers: the unsigned variant shares the signed variant's type.
    case xla::PrimitiveType::U16:
    case xla::PrimitiveType::S16:
      return at::ScalarType::Short;
    case xla::PrimitiveType::U32:
    case xla::PrimitiveType::S32:
      return at::ScalarType::Int;
    case xla::PrimitiveType::U64:
    case xla::PrimitiveType::S64:
      return at::ScalarType::Long;

    default:
      // Fall through to the unsupported-type report below.
      break;
  }
  XLA_ERROR() << "XLA type not supported: " << xla_type;
}