in torch_xla/csrc/helpers.cpp [439:484]
// Returns the primitive type that an operation mixing `type1` and `type2`
// should be computed in.
//
// Rules, applied in order (the first one that matches decides):
//   1. Identical types are returned unchanged.
//   2. If either type is complex, the result is complex. When both are
//      complex, the one with the larger byte size wins (type1 on a tie).
//   3. Otherwise, if either type is floating point, the result is floating
//      point. When both are floating point, the one with the larger byte size
//      wins (type1 on a tie).
//   4. Otherwise, if type2 has a strictly larger byte size, type2 wins.
//   5. Otherwise, if both are integral: the larger one wins, and two distinct
//      integral types of the same size promote to the unsigned type of that
//      bit width.
//   6. Otherwise, PRED yields to the other operand.
//   7. If nothing else applies, type1 wins.
xla::PrimitiveType XlaHelpers::PromoteType(xla::PrimitiveType type1,
                                           xla::PrimitiveType type2) {
  if (type1 == type2) {
    return type1;
  }
  const xla::int64_t size1 = xla::ShapeUtil::ByteSizeOfPrimitiveType(type1);
  const xla::int64_t size2 = xla::ShapeUtil::ByteSizeOfPrimitiveType(type2);
  if (xla::primitive_util::IsComplexType(type1)) {
    // Complex beats any non-complex type; between two complex types the wider
    // one wins, with ties going to the first operand.
    return (!xla::primitive_util::IsComplexType(type2) || size1 >= size2)
               ? type1
               : type2;
  }
  if (xla::primitive_util::IsComplexType(type2)) {
    return type2;
  }
  if (xla::primitive_util::IsFloatingPointType(type1)) {
    // Same policy as for complex types, one category down: floating point
    // beats any remaining non-floating type regardless of its size.
    return (!xla::primitive_util::IsFloatingPointType(type2) || size1 >= size2)
               ? type1
               : type2;
  }
  // type1 is neither complex nor floating point here, and type2 is not
  // complex. type2 wins if it is floating point or strictly wider.
  if (xla::primitive_util::IsFloatingPointType(type2) || size2 > size1) {
    return type2;
  }
  // From here on size1 >= size2 holds (the check above returned otherwise).
  if (xla::primitive_util::IsIntegralType(type1) &&
      xla::primitive_util::IsIntegralType(type2)) {
    if (size1 > size2) {
      return type1;
    }
    // At this point, they are not the same type, they are both integers, and
    // they have the same size. One of them must be unsigned and the other
    // signed, convert to unsigned.
    return xla::primitive_util::UnsignedIntegralTypeForBitWidth(
        xla::primitive_util::BitWidth(type1));
  }
  if (type1 == xla::PrimitiveType::PRED) {
    return type2;
  }
  if (type2 == xla::PrimitiveType::PRED) {
    return type1;
  }
  // If nothing matches the above logic, first operand wins.
  return type1;
}