in torch_xla/csrc/convert_ops.cpp [54:86]
xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType to, const Device* device) {
if (from == to) {
return op;
}
if (GetDeviceOrCurrent(device).hw_type != DeviceType::TPU) {
return xla::ConvertElementType(op, to);
}
switch (from) {
case xla::PrimitiveType::PRED:
case xla::PrimitiveType::S8:
case xla::PrimitiveType::U8:
case xla::PrimitiveType::S16:
case xla::PrimitiveType::U16:
case xla::PrimitiveType::S32:
case xla::PrimitiveType::U32:
case xla::PrimitiveType::BF16:
case xla::PrimitiveType::F32:
return xla::ConvertElementType(op, to);
case xla::PrimitiveType::S64:
case xla::PrimitiveType::U64: {
switch (to) {
case xla::PrimitiveType::PRED:
return ExplicitBooleanConvert(op, from);
default:
return xla::ConvertElementType(op, to);
}
break;
}
default:
XLA_ERROR() << "Unsupported XLA type " << from;
}
}