def _register_expr_op()

in python/tvm/script/parser/tir/operation.py [0:0]


def _register_expr_op(ty: Type):  # pylint: disable=invalid-name
    """Register TIR comparison and boolean operator handlers for ``ty``.

    Sets ``ty._dispatch_type`` to ``ty`` and, via ``register_op``, installs
    handlers for ``==``, ``!=``, ``<``, ``<=``, ``>``, ``>=``, ``and`` and
    ``or`` with operand index 0 and 1, and for ``not`` with operand index 0.
    Arithmetic and bitwise binary operators are not registered here (see the
    comments in the registration loop below).

    Parameters
    ----------
    ty : Type
        The Python type whose operators are being registered.
    """
    ty._dispatch_type = ty  # pylint: disable=protected-access

    def _as_bool_expr(value):
        # Python ``bool`` literals become TIR boolean immediates; anything else
        # is passed through unchanged.
        if isinstance(value, bool):
            return IntImm("bool", value)
        return value

    def _bool_op(a, b, vector_op, scalar_op):
        # Shared body of ``and`` / ``or``: if either side has more than one lane
        # use the ``&`` / ``|`` operator overload, otherwise build the logical
        # TIR node (``tir.And`` / ``tir.Or``) directly.
        a = _as_bool_expr(a)
        b = _as_bool_expr(b)
        if DataType(a.dtype).lanes > 1 or DataType(b.dtype).lanes > 1:
            return vector_op(a, b)
        return scalar_op(a, b)

    def _and(a, b):
        return _bool_op(a, b, lambda x, y: x & y, tir.And)

    def _or(a, b):
        return _bool_op(a, b, lambda x, y: x | y, tir.Or)

    def _get_type_str(dtype: str):
        # Scalar element type of ``dtype``: returned unchanged for 1-lane dtypes,
        # otherwise the text before the first "x" (e.g. "float32x4" -> "float32").
        if DataType(dtype).lanes == 1:
            return dtype
        return dtype[: dtype.find("x")]

    def _auto_broadcast(a, b, op):
        """Apply the binary TIR node constructor ``op`` to ``a`` and ``b``.

        Python ``int`` / ``float`` literals are first converted to immediates
        whose scalar dtype follows the other operand where possible. Then, when
        the lane counts differ and one side has a single lane, that side is
        wrapped in ``tir.Broadcast`` to the other side's lane count.

        Raises
        ------
        TypeError
            If the left operand is not a PrimExpr after conversion, if an
            integer literal on the right cannot be matched to the left
            operand's dtype, or if the lane counts cannot be reconciled.
        """
        # Left operand: convert a Python literal, guided by the right operand.
        if isinstance(a, int):
            if hasattr(b, "dtype"):
                b_code = DataType(b.dtype).type_code
                if b_code in (DataTypeCode.INT, DataTypeCode.UINT):
                    a = IntImm(_get_type_str(b.dtype), a)
                elif b_code == DataTypeCode.FLOAT:
                    a = FloatImm(_get_type_str(b.dtype), a)
            elif isinstance(b, float):
                a = FloatImm("float32", a)
            else:
                a = IntImm("int32", a)
        elif isinstance(a, float):
            if DataType(b.dtype).type_code == DataTypeCode.FLOAT:
                a = FloatImm(_get_type_str(b.dtype), a)
            else:
                a = FloatImm("float32", a)

        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if not isinstance(a, tir.PrimExpr):
            raise TypeError(
                f"Operand should be a PrimExpr, but got {type(a).__name__}: {a!r}."
            )

        # Right operand: convert a Python literal to the left operand's scalar dtype.
        if isinstance(b, int):
            a_code = DataType(a.dtype).type_code
            if a_code in (DataTypeCode.INT, DataTypeCode.UINT):
                b = IntImm(_get_type_str(a.dtype), b)
            elif a_code == DataTypeCode.FLOAT:
                b = FloatImm(_get_type_str(a.dtype), b)
            else:
                # No immediate type is chosen for other dtype codes; fail with a
                # clear message instead of an AttributeError on ``b.dtype`` below.
                raise TypeError(
                    f"Cannot convert integer literal {b!r} to match operand dtype {a.dtype}."
                )
        elif isinstance(b, float):
            b = FloatImm(_get_type_str(a.dtype), b)

        a_lanes = DataType(a.dtype).lanes
        b_lanes = DataType(b.dtype).lanes
        if a_lanes == b_lanes:
            return op(a, b)
        # Lane counts differ from here on: broadcast whichever side is 1-lane.
        if a_lanes == 1:
            return op(tir.Broadcast(a, b_lanes), b)
        if b_lanes == 1:
            return op(a, tir.Broadcast(b, a_lanes))
        raise TypeError(
            f"Cannot broadcast operands of dtype {a.dtype} and {b.dtype}: "
            "lane counts must match or one operand must have a single lane."
        )

    def _eq(a, b):
        return _auto_broadcast(a, b, tir.EQ)

    def _ne(a, b):
        return _auto_broadcast(a, b, tir.NE)

    def _lt(a, b):
        return _auto_broadcast(a, b, tir.LT)

    def _le(a, b):
        return _auto_broadcast(a, b, tir.LE)

    def _gt(a, b):
        return _auto_broadcast(a, b, tir.GT)

    def _ge(a, b):
        return _auto_broadcast(a, b, tir.GE)

    def r(op: Type, i: int, m: OpMethod):  # pylint: disable=invalid-name
        register_op(ty, op, i)(m)

    for i in (0, 1):
        # Case 1. binop
        # doc.Add <-- is overloaded
        # doc.Sub <-- is overloaded
        # doc.Mult <-- is overloaded
        # doc.Div <-- is overloaded
        # doc.FloorDiv <-- is overloaded
        # doc.Mod <-- is overloaded
        # doc.LShift <-- is overloaded
        # doc.RShift <-- is overloaded
        # doc.BitOr <-- is overloaded
        # doc.BitXor <-- is overloaded
        # doc.BitAnd <-- is overloaded
        # doc.MatMult <-- not implemented
        # doc.Pow <-- not implemented
        # Case 2. cmpop
        r(doc.Eq, i, _eq)
        r(doc.NotEq, i, _ne)
        r(doc.Lt, i, _lt)
        r(doc.LtE, i, _le)
        r(doc.Gt, i, _gt)
        r(doc.GtE, i, _ge)
        # doc.Is <-- not implemented
        # doc.IsNot <-- not implemented
        # doc.In <-- not implemented
        # doc.NotIn <-- not implemented
        # Case 3. boolop
        r(doc.And, i, _and)
        r(doc.Or, i, _or)
    # Case 4. unaryop (only operand index 0 exists)
    # doc.Invert <-- is overloaded
    r(doc.Not, 0, tir.Not)