void BinaryOpMatchTypes()

in src/tir/op/op.cc [142:235]


// Make `lhs` and `rhs` agree on dtype so that a binary operator can be built from them.
//
// Both operands are taken by reference and may be replaced in place:
//   1. Undefined operands are rejected with a user-facing ValueError.
//   2. BroadcastToMatchLanes (defined elsewhere) is applied in both directions;
//      presumably this broadcasts a scalar operand to the other side's lane count.
//   3. The lane shapes must then agree: either both operands are scalable vectors with
//      equal vscale factors, or both are fixed-length with equal lane counts.
//   4. If the dtypes still differ, exactly one operand is wrapped in a cast to the other
//      operand's dtype, chosen by the first matching promotion rule below. If no rule
//      applies, the function aborts via LOG(FATAL).
//
// `span` is currently not used by this function.
// NOTE(review): the inserted casts could presumably carry `span` for better source
// locations if `cast` accepts one — confirm against its declaration before wiring it in.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
  CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
  CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator";
  if (lhs.dtype() == rhs.dtype()) return;

  BroadcastToMatchLanes(lhs, rhs);
  BroadcastToMatchLanes(rhs, lhs);

  // Nothing below modifies lhs/rhs until a cast is inserted, so these snapshots stay
  // accurate for the rest of the function.
  const DataType ltype = lhs.dtype();
  const DataType rtype = rhs.dtype();

  ICHECK(ltype.is_scalable_vector() == rtype.is_scalable_vector())
      << "Can't match scalable and fixed length vectors";

  // Scalable vectors are compared by vscale factor, fixed-length vectors by lane count.
  const bool lanes_match = ltype.is_scalable_vector()
                               ? ltype.vscale_factor() == rtype.vscale_factor()
                               : ltype.lanes() == rtype.lanes();
  ICHECK(lanes_match) << "Cannot match type " << ltype << " vs " << rtype;

  // Broadcasting alone may have been enough to make the dtypes identical.
  if (ltype == rtype) return;

  // We keep dtype conversion relatively consistent to reduce the amount of code generated by
  // operators. This can also help users find potential type conversion problems. The rules
  // below are tried in order, and the first one that matches decides which side is cast.
  if (ltype.is_float() && rtype.is_float()) {
    // Two dissimilar floats: cast the lower-bit one to the higher-bit one.
    // E.g. fp16 + fp32 --> fp32 + fp32
    if (ltype.bits() < rtype.bits()) {
      lhs = cast(rtype, lhs);
    } else {
      rhs = cast(ltype, rhs);
    }
  } else if (!ltype.is_float() &&
             (rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) {
    // rhs is a float or a registered custom type and lhs is not a float: cast lhs to rhs's type.
    lhs = cast(rtype, lhs);
  } else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) &&
             !rtype.is_float()) {
    // Mirror of the previous rule: cast rhs to lhs's float / registered custom type.
    rhs = cast(ltype, rhs);
  } else if (!ltype.is_bfloat16() && rtype.is_bfloat16()) {
    // From here on neither side is a float or a registered custom type (one of the rules
    // above would have matched), so bfloat16 needs no custom-type lookup.
    // Cast the non-bfloat16 lhs to bfloat16.
    lhs = cast(rtype, lhs);
  } else if (ltype.is_bfloat16() && !rtype.is_bfloat16()) {
    // Cast the non-bfloat16 rhs to bfloat16.
    rhs = cast(ltype, rhs);
  } else if (!ltype.is_float8() && rtype.is_float8()) {
    // Cast lhs to float8 when only rhs is a float8.
    lhs = cast(rtype, lhs);
  } else if (ltype.is_float8() && !rtype.is_float8()) {
    // Cast rhs to float8 when only lhs is a float8.
    rhs = cast(ltype, rhs);
  } else if (!ltype.is_float4() && rtype.is_float4()) {
    // Cast lhs to float4 when only rhs is a float4.
    lhs = cast(rtype, lhs);
  } else if (ltype.is_float4() && !rtype.is_float4()) {
    // Cast rhs to float4 when only lhs is a float4.
    rhs = cast(ltype, rhs);
  } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
    // Same signedness: promote to the wider type, e.g. int8 + int16 --> int16 + int16.
    if (ltype.bits() < rtype.bits()) {
      lhs = cast(rtype, lhs);
    } else {
      rhs = cast(ltype, rhs);
    }
  } else if ((ltype.is_int() && rtype.is_uint()) || (ltype.is_uint() && rtype.is_int())) {
    // Mixed signedness: the wider type wins; at equal width the unsigned type wins.
    if (ltype.bits() < rtype.bits()) {
      lhs = cast(rtype, lhs);
    } else if (ltype.bits() > rtype.bits()) {
      rhs = cast(ltype, rhs);
    } else if (ltype.is_uint()) {
      rhs = cast(ltype, rhs);
    } else {
      lhs = cast(rtype, lhs);
    }
  } else {
    // Report the offending operands in the error itself rather than in a separate INFO log,
    // so that whoever receives the failure also sees the expressions involved.
    LOG(FATAL) << "Cannot match type " << ltype << " vs " << rtype << " (lhs: " << lhs
               << ", rhs: " << rhs << ")";
  }
}