def InjectALUIntrin()

in vta/python/vta/transform.py [0:0]


def InjectALUIntrin():
    """Pass to inject ALU micro-ops.

    Every statement for which ``_match_pragma(stmt, "alu")`` holds is replaced
    by a sequence of ``VTAUopLoopBegin`` calls, one ``tir.vta.uop_push`` call
    and the matching ``VTAUopLoopEnd`` calls. Statements that do not match are
    returned unchanged.

    Returns
    -------
    fpass : tvm.transform.Pass
        The pass
    """

    def _ftransform(func, mod, ctx):
        env = get_env()
        idxm = tvm.tir.indexmod
        analyzer = tvm.arith.Analyzer()

        def _do_fold(stmt):
            """Rewrite one "alu"-pragma statement into ALU micro-op calls."""

            def _equal(x, y):
                # Two expressions are equal when their difference simplifies to 0.
                return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)

            def _flatten_loop(src_coeff, dst_coeff, extents):
                """Merge adjacent loop levels that can be walked as one loop.

                Two neighbouring levels are merged when, for both the source
                and the destination, the outer coefficient equals the inner
                coefficient times the inner extent. The last entry of each
                coefficient list is carried through unchanged as the last
                entry of the returned list.

                Returns
                -------
                (src_coeff, dst_coeff, extents) : tuple of list
                    The flattened coefficient lists and loop extents.
                """
                src_coeff = list(src_coeff)
                dst_coeff = list(dst_coeff)
                extents = list(extents)
                # Results are accumulated innermost-first and reversed at the end.
                rev_src_coeff = [src_coeff.pop()]
                rev_dst_coeff = [dst_coeff.pop()]
                rev_extents = []
                assert src_coeff
                vsrc = src_coeff.pop()
                vdst = dst_coeff.pop()
                vext = extents.pop()
                while src_coeff:
                    next_src = src_coeff.pop()
                    next_dst = dst_coeff.pop()
                    next_ext = extents.pop()

                    if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
                        # The outer level continues the inner one: fuse them.
                        vext = analyzer.simplify(vext * next_ext)
                    else:
                        rev_src_coeff.append(vsrc)
                        rev_dst_coeff.append(vdst)
                        rev_extents.append(vext)
                        vsrc = next_src
                        vdst = next_dst
                        vext = next_ext
                rev_src_coeff.append(vsrc)
                rev_dst_coeff.append(vdst)
                rev_extents.append(vext)
                rev_src_coeff.reverse()
                rev_dst_coeff.reverse()
                rev_extents.reverse()

                return rev_src_coeff, rev_dst_coeff, rev_extents

            if _match_pragma(stmt, "alu"):
                # Get to the innermost loop body
                loop_body = stmt.body
                nest_size = 0
                while isinstance(loop_body, tvm.tir.For):
                    loop_body = loop_body.body
                    nest_size += 1
                # Get the src/dst arguments
                dst_var = loop_body.buffer_var
                dst_idx = loop_body.index
                # Derive loop variables and extents (outermost first)
                tmp_body = stmt.body
                indices = []
                extents = []
                for _ in range(nest_size):
                    indices.append(tmp_body.loop_var)
                    extents.append(tmp_body.extent)
                    tmp_body = tmp_body.body
                # Derive opcode
                if isinstance(loop_body.value, tvm.tir.Add):
                    alu_opcode = env.dev.ALU_OPCODE_ADD
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Sub):
                    alu_opcode = env.dev.ALU_OPCODE_SUB
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Mul):
                    alu_opcode = env.dev.ALU_OPCODE_MUL
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Min):
                    alu_opcode = env.dev.ALU_OPCODE_MIN
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Max):
                    alu_opcode = env.dev.ALU_OPCODE_MAX
                    lhs = loop_body.value.a
                    rhs = loop_body.value.b
                elif isinstance(loop_body.value, tvm.tir.Call):
                    if loop_body.value.op.name == "tir.shift_left":
                        # A left shift is emitted as SHR with the negated amount.
                        alu_opcode = env.dev.ALU_OPCODE_SHR
                        lhs = loop_body.value.args[0]
                        rhs = analyzer.simplify(-loop_body.value.args[1])
                    elif loop_body.value.op.name == "tir.shift_right":
                        alu_opcode = env.dev.ALU_OPCODE_SHR
                        lhs = loop_body.value.args[0]
                        rhs = loop_body.value.args[1]
                    else:
                        # Report the callee through `.op.name`, the same
                        # accessor used by the checks above.
                        raise RuntimeError(
                            "Function call not recognized %s" % (loop_body.value.op.name)
                        )
                elif isinstance(loop_body.value, tvm.tir.Load):
                    # A plain load is emitted as SHR with an immediate of 0.
                    alu_opcode = env.dev.ALU_OPCODE_SHR
                    lhs = loop_body.value
                    rhs = tvm.tir.const(0, "int32")
                else:
                    raise RuntimeError(
                        "Expression not recognized %s, %s, %s"
                        % (type(loop_body.value), str(loop_body.value), str(stmt))
                    )

                # Derive array index coefficients
                dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
                # Check if lhs/rhs is immediate
                use_imm = False
                imm_val = None
                if isinstance(rhs, tvm.tir.IntImm):
                    assert lhs.buffer_var.same_as(dst_var)
                    src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
                    use_imm = True
                    imm_val = rhs
                if isinstance(lhs, tvm.tir.IntImm):
                    assert rhs.buffer_var.same_as(dst_var)
                    src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
                    use_imm = True
                    imm_val = lhs
                if imm_val is None:
                    # Neither operand is an immediate: both must read from the
                    # destination buffer.
                    imm_val = 0
                    assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
                    src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
                    src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
                    # Determine which side has the same coefficients
                    lhs_equal = True
                    rhs_equal = True
                    for i, coef in enumerate(dst_coeff):
                        if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
                            lhs_equal = False
                        if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
                            rhs_equal = False
                    # Make sure at least one of the source is identical to the
                    # destination (in-place computation)
                    assert lhs_equal or rhs_equal
                    # Assign the source coefficients: the operand that is NOT
                    # the in-place one becomes the source.
                    if lhs_equal:
                        src_coeff = src_rhs_coeff
                    else:
                        src_coeff = src_lhs_coeff

                # Ensure that we have the proper tensor dimensions in the
                # innermost loop (pattern match)
                block_elems = env.BATCH * env.BLOCK_OUT
                src_coeff = list(src_coeff)
                dst_coeff = list(dst_coeff)
                extents = list(extents)
                assert len(src_coeff) > 1
                assert len(dst_coeff) > 1
                assert len(extents) != 0
                # The last entry of each list must be a multiple of
                # BATCH * BLOCK_OUT.
                assert tvm.ir.structural_equal(
                    analyzer.simplify(idxm(src_coeff[-1], block_elems)), 0
                )
                assert tvm.ir.structural_equal(
                    analyzer.simplify(idxm(dst_coeff[-1], block_elems)), 0
                )
                # The innermost loop must have coefficient 1 ...
                assert tvm.ir.structural_equal(src_coeff[-2], 1)
                assert tvm.ir.structural_equal(dst_coeff[-2], 1)
                if env.BATCH > 1:
                    # ... and, with batching, the next loop out must have
                    # coefficient BLOCK_OUT.
                    assert len(src_coeff) > 2
                    assert len(dst_coeff) > 2
                    assert len(extents) > 1
                    assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
                    assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)

                # Apply tensorization of the loop coefficients: drop the
                # innermost loop level(s) checked above, keep the last entry,
                # and divide everything by BATCH * BLOCK_OUT.
                src_offset = src_coeff[-1]
                dst_offset = dst_coeff[-1]
                if env.BATCH == 1:
                    src_coeff = src_coeff[:-2]
                    dst_coeff = dst_coeff[:-2]
                    extents = extents[:-1]
                else:
                    src_coeff = src_coeff[:-3]
                    dst_coeff = dst_coeff[:-3]
                    extents = extents[:-2]
                src_coeff.append(src_offset)
                dst_coeff.append(dst_offset)
                src_coeff = [analyzer.simplify(c // block_elems) for c in src_coeff]
                dst_coeff = [analyzer.simplify(c // block_elems) for c in dst_coeff]

                # Flatten the outer loops
                if extents:
                    src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)

                # Insert ALU micro-ops
                irb = tvm.tir.ir_builder.create()
                for idx, extent in enumerate(extents):
                    irb.emit(
                        tvm.tir.call_extern(
                            "int32", "VTAUopLoopBegin", extent, dst_coeff[idx], src_coeff[idx], 0
                        )
                    )
                use_imm = int(use_imm)
                irb.emit(
                    tvm.tir.call_intrin(
                        "int32",
                        "tir.vta.uop_push",
                        1,
                        0,
                        dst_coeff[-1],
                        src_coeff[-1],
                        0,
                        alu_opcode,
                        use_imm,
                        imm_val,
                    )
                )
                # Close every loop opened above.
                for _ in extents:
                    irb.emit(tvm.tir.call_extern("int32", "VTAUopLoopEnd"))
                return irb.get()
            return stmt

        # _do_fold is passed as the third argument of ir_transform, with the
        # node filter limited to "tir.AttrStmt".
        return func.with_body(
            tvm.tir.stmt_functor.ir_transform(func.body, None, _do_fold, ["tir.AttrStmt"])
        )

    return tvm.tir.transform.prim_func_pass(
        _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin"
    )