in torch_xla/csrc/tensor_util.cpp [886:921]
at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
  // Map an XLA on-device primitive type back to the PyTorch scalar type a
  // tensor of that type should surface as. When the Use*/Downcast* runtime
  // flags are set, the device stores tensors in a reduced-precision type, so
  // that type maps back to the wider host-facing type the user originally
  // requested (e.g. device BF16 -> host Float).
  using PT = xla::PrimitiveType;
  switch (xla_type) {
    case PT::BF16:
      // Device BF16 may be standing in for a host Float tensor.
      if (UseBF16() || DowncastBF16()) {
        return at::ScalarType::Float;
      }
      return at::ScalarType::BFloat16;
    case PT::F16:
      // Device F16 may be standing in for a host Float tensor.
      if (UseF16() || DowncastF16()) {
        return at::ScalarType::Float;
      }
      return at::ScalarType::Half;
    case PT::F32:
      // Under downcast mode, device F32 is the stand-in for host Double.
      if (DowncastBF16() || DowncastF16()) {
        return at::ScalarType::Double;
      }
      return at::ScalarType::Float;
    case PT::F64:
      return at::ScalarType::Double;
    case PT::PRED:
      return at::ScalarType::Bool;
    case PT::U8:
      return at::ScalarType::Byte;
    case PT::S8:
      return at::ScalarType::Char;
    // ATen has no unsigned 16/32/64-bit types, so the unsigned XLA types
    // fold into the signed ATen type of the same width.
    case PT::S16:
    case PT::U16:
      return at::ScalarType::Short;
    case PT::S32:
    case PT::U32:
      return at::ScalarType::Int;
    case PT::S64:
    case PT::U64:
      return at::ScalarType::Long;
    case PT::C64:
      return at::ScalarType::ComplexFloat;
    case PT::C128:
      return at::ScalarType::ComplexDouble;
    default:
      // No ATen equivalent (e.g. tuple/opaque types); XLA_ERROR aborts.
      XLA_ERROR() << "XLA type not supported: " << xla_type;
  }
}