in src/tir/op/op.cc [142:235]
/*!
 * \brief Bring the two operands of a binary operator to one common dtype.
 *
 * Both operands are updated in place. First BroadcastToMatchLanes is applied in
 * both directions (presumably broadcasting a scalar operand to the other
 * operand's lane count -- see that helper). Then the lane counts are required to
 * agree, and the element types are unified by inserting casts according to the
 * ordered rules below.
 *
 * \param lhs Left operand; may be replaced by a broadcast and/or cast of itself.
 * \param rhs Right operand; may be replaced by a broadcast and/or cast of itself.
 * \param span Source span attached to every cast node created here.
 *
 * Fails via CHECK when an operand is undefined. Fails via ICHECK when a scalable
 * vector is mixed with a fixed-length one or the lane counts differ. Fails via
 * LOG(FATAL) when no promotion rule applies to the pair of element types.
 */
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
  CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
  CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator";
  if (lhs.dtype() == rhs.dtype()) return;

  BroadcastToMatchLanes(lhs, rhs);
  BroadcastToMatchLanes(rhs, lhs);

  // lhs/rhs are not modified again before the casts below, so read the types once.
  const DataType ltype = lhs.dtype();
  const DataType rtype = rhs.dtype();

  ICHECK(ltype.is_scalable_vector() == rtype.is_scalable_vector())
      << "Can't match scalable and fixed length vectors";
  // Scalable vectors are compared by their vscale factor, fixed-length ones by lane count.
  const bool lanes_match = ltype.is_scalable_vector()
                               ? ltype.vscale_factor() == rtype.vscale_factor()
                               : ltype.lanes() == rtype.lanes();
  ICHECK(lanes_match) << "Cannot match type " << ltype << " vs " << rtype;

  // Broadcasting alone may have been enough to make the types agree.
  if (ltype == rtype) return;

  // We keep dtype conversions relatively consistent to reduce the amount of code generated by
  // operators. This can help users find potential type conversion problems. The following are
  // the exceptions. Order matters: the first matching rule wins. Types registered in
  // datatype::Registry (custom datatypes) are treated like floats by the second and third rules.
  if (ltype.is_float() && rtype.is_float()) {
    // Given two dissimilar floats, cast the lower bit version to the higher bit version.
    // E.g. fp16 + fp32 --> fp32 + fp32
    if (ltype.bits() < rtype.bits()) {
      lhs = cast(rtype, lhs, span);
    } else {
      rhs = cast(ltype, rhs, span);
    }
  } else if (!ltype.is_float() &&
             (rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
    // Cast lhs to the float (or custom) type of rhs.
    lhs = cast(rtype, lhs, span);
  } else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
             !rtype.is_float()) {
    // Cast rhs to the float (or custom) type of lhs.
    rhs = cast(ltype, rhs, span);
  } else if (!ltype.is_bfloat16() &&
             (rtype.is_bfloat16() ||
              datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
    // Cast lhs to bfloat16 when rhs is a bfloat16.
    lhs = cast(rtype, lhs, span);
  } else if ((ltype.is_bfloat16() ||
              datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
             !rtype.is_bfloat16()) {
    // Cast rhs to bfloat16 when lhs is a bfloat16.
    rhs = cast(ltype, rhs, span);
  } else if (!ltype.is_float8() && rtype.is_float8()) {
    // Cast lhs to float8 when rhs is a float8.
    lhs = cast(rtype, lhs, span);
  } else if (ltype.is_float8() && !rtype.is_float8()) {
    // Cast rhs to float8 when lhs is a float8.
    rhs = cast(ltype, rhs, span);
  } else if (!ltype.is_float4() && rtype.is_float4()) {
    // Cast lhs to float4 when rhs is a float4.
    lhs = cast(rtype, lhs, span);
  } else if (ltype.is_float4() && !rtype.is_float4()) {
    // Cast rhs to float4 when lhs is a float4.
    rhs = cast(ltype, rhs, span);
  } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
    // Same signedness: promote to the wider type, e.g. int8 + int16 --> int16 + int16.
    if (ltype.bits() < rtype.bits()) {
      lhs = cast(rtype, lhs, span);
    } else {
      rhs = cast(ltype, rhs, span);
    }
  } else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) {
    // Mixed signedness: the wider type wins.
    if (ltype.bits() < rtype.bits()) {
      lhs = cast(rtype, lhs, span);
    } else if (ltype.bits() > rtype.bits()) {
      rhs = cast(ltype, rhs, span);
    } else if (ltype.is_uint()) {
      // Equal widths: convert the signed operand to the unsigned type, as C's usual
      // arithmetic conversions do.
      rhs = cast(ltype, rhs, span);
    } else {
      lhs = cast(rtype, lhs, span);
    }
  } else {
    // Report the offending operands in the fatal message itself rather than in a
    // separate INFO log line.
    LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype << " (lhs: " << lhs
               << ", rhs: " << rhs << ")";
  }
}