in torch_xla/csrc/tensor_util.cpp [954:992]
// Returns the element type that a tensor of logical element type `type`
// should use on the target device.
//
// The mapping depends on two things:
//   * the device's hardware type: on TPU, F64 -> F32, U16 -> U32,
//     S16 -> S32 and C128 -> C64;
//   * the global precision knobs UseF16(), UseBF16(), DowncastF16(),
//     DowncastBF16() and Use32BitLong().
//
// Args:
//   type:   the requested (logical) element type.
//   device: the target device, passed to GetDeviceOrCurrent() to obtain the
//           device actually consulted.
// Returns:
//   The device-side element type. Types not handled below are returned
//   unchanged.
xla::PrimitiveType GetDevicePrimitiveType(xla::PrimitiveType type,
                                          const Device* device) {
  const Device target_device = GetDeviceOrCurrent(device);
  const bool is_tpu = target_device.hw_type == DeviceType::TPU;

  switch (type) {
    case xla::PrimitiveType::F64: {
      // Explicit reduced-precision modes win; F16 is checked before BF16.
      if (UseF16()) {
        return xla::PrimitiveType::F16;
      }
      if (UseBF16()) {
        return xla::PrimitiveType::BF16;
      }
      // Downcast modes map F64 to F32 here.
      if (DowncastBF16() || DowncastF16()) {
        return xla::PrimitiveType::F32;
      }
      // No precision knob set: TPU narrows F64 to F32, others keep F64.
      if (is_tpu) {
        return xla::PrimitiveType::F32;
      }
      return xla::PrimitiveType::F64;
    }
    case xla::PrimitiveType::F32: {
      // Either F16 knob takes precedence over either BF16 knob.
      if (UseF16() || DowncastF16()) {
        return xla::PrimitiveType::F16;
      }
      if (UseBF16() || DowncastBF16()) {
        return xla::PrimitiveType::BF16;
      }
      return xla::PrimitiveType::F32;
    }
    case xla::PrimitiveType::U16:
      // TPU widens 16-bit unsigned integers to 32 bits.
      if (is_tpu) {
        return xla::PrimitiveType::U32;
      }
      return xla::PrimitiveType::U16;
    case xla::PrimitiveType::S16:
      // TPU widens 16-bit signed integers to 32 bits.
      if (is_tpu) {
        return xla::PrimitiveType::S32;
      }
      return xla::PrimitiveType::S16;
    case xla::PrimitiveType::S64:
      if (Use32BitLong()) {
        return xla::PrimitiveType::S32;
      }
      return xla::PrimitiveType::S64;
    case xla::PrimitiveType::U64:
      if (Use32BitLong()) {
        return xla::PrimitiveType::U32;
      }
      return xla::PrimitiveType::U64;
    case xla::PrimitiveType::C128:
      // TPU narrows 128-bit complex to 64-bit complex.
      if (is_tpu) {
        return xla::PrimitiveType::C64;
      }
      return xla::PrimitiveType::C128;
    default:
      // All other element types pass through unchanged.
      return type;
  }
}