in Sources/x10/xla_tensor/tensor_util.cpp [784:820]
// Returns the element type that data of logical type `type` is given on
// `device`. The device is resolved through GetDeviceOrCurrent(); a null
// `device` presumably means "the current device" — confirm against that
// helper's definition.
//
// Mapping performed here:
//   F64  -> F16 if UseF16(), else BF16 if UseBF16(), else F32 on TPU and F64
//           on every other hardware type.
//   F32  -> F16 if UseF16(), else BF16 if UseBF16(), else F32.
//   U16  -> U32 on TPU, unchanged elsewhere.
//   S16  -> S32 on TPU, unchanged elsewhere.
//   C128 -> C64 on TPU, unchanged elsewhere.
//   S64  -> S32 if Use32BitLong(), else S64.
//   U64  -> U32 if Use32BitLong(), else U64.
//   Any other type is returned unchanged.
// UseF16() is consulted before UseBF16(), so F16 wins when both are enabled.
xla::PrimitiveType GetDevicePrimitiveType(xla::PrimitiveType type,
                                          const Device* device) {
  const Device resolved_device = GetDeviceOrCurrent(device);
  const bool on_tpu = resolved_device.hw_type == DeviceType::TPU;

  switch (type) {
    case xla::PrimitiveType::F64:
    case xla::PrimitiveType::F32: {
      // When S4TF will support native BF16 type, the global configuration can
      // be replaced (or augmented) with the proper mapping.
      if (UseF16()) {
        return xla::PrimitiveType::F16;
      }
      if (UseBF16()) {
        return xla::PrimitiveType::BF16;
      }
      // Only F64 on a non-TPU device keeps double precision; everything else
      // reaching this point ends up as F32.
      if (type == xla::PrimitiveType::F64 && !on_tpu) {
        return xla::PrimitiveType::F64;
      }
      return xla::PrimitiveType::F32;
    }

    // Types whose device representation differs only on TPU.
    case xla::PrimitiveType::U16:
      return on_tpu ? xla::PrimitiveType::U32 : type;
    case xla::PrimitiveType::S16:
      return on_tpu ? xla::PrimitiveType::S32 : type;
    case xla::PrimitiveType::C128:
      return on_tpu ? xla::PrimitiveType::C64 : type;

    // 64-bit integers are narrowed when the global 32-bit-long flag is set.
    case xla::PrimitiveType::S64:
      return Use32BitLong() ? xla::PrimitiveType::S32 : type;
    case xla::PrimitiveType::U64:
      return Use32BitLong() ? xla::PrimitiveType::U32 : type;

    default:
      break;
  }
  return type;
}