in vta/python/vta/transform.py [0:0]
def InjectALUIntrin():
    """Pass to inject ALU micro-ops.

    Every statement carrying the ``"alu"`` pragma is expected to be a nest of
    ``For`` loops whose innermost body is a single store into a destination
    buffer.  The nest is replaced by ``VTAUopLoopBegin``/``VTAUopLoopEnd``
    extern calls (one pair per remaining outer loop after loop flattening)
    around a single ``tir.vta.uop_push`` intrinsic carrying the ALU opcode and,
    when one operand is an integer constant, the immediate value.

    Returns
    -------
    fpass : tvm.transform.Pass
        The pass
    """

    def _ftransform(func, mod, ctx):
        env = get_env()
        idxm = tvm.tir.indexmod
        analyzer = tvm.arith.Analyzer()

        def _equal(x, y):
            """Return True if ``x - y`` simplifies to the constant 0."""
            return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)

        def _coeffs_match(coeffs, other):
            """Return True if both coefficient arrays are element-wise equal.

            The length check guards against ``detect_linear_equation`` having
            returned a shorter array for a non-linear index, which previously
            surfaced as an IndexError.
            """
            if len(coeffs) != len(other):
                return False
            return all(
                tvm.ir.structural_equal(coef, other[i]) for i, coef in enumerate(coeffs)
            )

        def _flatten_loop(src_coeff, dst_coeff, extents):
            """Fuse adjacent loops whose strides make them contiguous.

            ``src_coeff``/``dst_coeff`` hold one stride per loop (outermost
            first) followed by a trailing offset; ``extents`` holds one extent
            per loop.  A loop is fused into the loop inside it when both its
            src and dst strides equal the inner stride times the inner extent.
            Returns new ``(src_coeff, dst_coeff, extents)`` in the same layout.
            """
            src_coeff = list(src_coeff)
            dst_coeff = list(dst_coeff)
            extents = list(extents)
            # The trailing offsets are carried through unchanged.
            rev_src_coeff = [src_coeff.pop()]
            rev_dst_coeff = [dst_coeff.pop()]
            rev_extents = []
            assert src_coeff
            # Walk from the innermost loop outwards, growing the current
            # (possibly fused) loop while the next one is contiguous with it.
            vsrc = src_coeff.pop()
            vdst = dst_coeff.pop()
            vext = extents.pop()
            while src_coeff:
                next_src = src_coeff.pop()
                next_dst = dst_coeff.pop()
                next_ext = extents.pop()
                if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
                    vext = analyzer.simplify(vext * next_ext)
                else:
                    rev_src_coeff.append(vsrc)
                    rev_dst_coeff.append(vdst)
                    rev_extents.append(vext)
                    vsrc = next_src
                    vdst = next_dst
                    vext = next_ext
            rev_src_coeff.append(vsrc)
            rev_dst_coeff.append(vdst)
            rev_extents.append(vext)
            rev_src_coeff.reverse()
            rev_dst_coeff.reverse()
            rev_extents.reverse()
            return rev_src_coeff, rev_dst_coeff, rev_extents

        def _derive_alu_op(stmt, value):
            """Map the stored expression ``value`` onto ``(opcode, lhs, rhs)``.

            Both shifts become ``SHR`` (a left shift by ``n`` is encoded as a
            right shift by ``-n``) and a bare load becomes ``SHR`` by 0.
            Raises RuntimeError for any other expression.
            """
            if isinstance(value, tvm.tir.Add):
                return env.dev.ALU_OPCODE_ADD, value.a, value.b
            if isinstance(value, tvm.tir.Sub):
                return env.dev.ALU_OPCODE_SUB, value.a, value.b
            if isinstance(value, tvm.tir.Mul):
                return env.dev.ALU_OPCODE_MUL, value.a, value.b
            if isinstance(value, tvm.tir.Min):
                return env.dev.ALU_OPCODE_MIN, value.a, value.b
            if isinstance(value, tvm.tir.Max):
                return env.dev.ALU_OPCODE_MAX, value.a, value.b
            if isinstance(value, tvm.tir.Call):
                # The callee name lives on ``op``; reuse it for the error
                # message too instead of reading ``value.name``.
                op_name = value.op.name
                if op_name == "tir.shift_left":
                    return (
                        env.dev.ALU_OPCODE_SHR,
                        value.args[0],
                        analyzer.simplify(-value.args[1]),
                    )
                if op_name == "tir.shift_right":
                    return env.dev.ALU_OPCODE_SHR, value.args[0], value.args[1]
                raise RuntimeError("Function call not recognized %s" % (op_name))
            if isinstance(value, tvm.tir.Load):
                return env.dev.ALU_OPCODE_SHR, value, tvm.tir.const(0, "int32")
            raise RuntimeError(
                "Expression not recognized %s, %s, %s"
                % (type(value), str(value), str(stmt))
            )

        def _derive_src(lhs, rhs, dst_var, dst_coeff, indices):
            """Pick the ALU source operand and the immediate value, if any.

            Returns ``(src_coeff, use_imm, imm_val)``.  When one side is an
            ``IntImm`` the other side must load from ``dst_var`` and supplies
            the source coefficients.  Otherwise both sides must load from
            ``dst_var``; the side whose coefficients equal ``dst_coeff`` is the
            in-place operand and the other side becomes the source.

            NOTE(review): the in-place side is dropped whether it was the lhs
            or the rhs.  If the micro-op computes ``dst <op> src``, then for
            non-commutative ops (SUB, SHR) with the destination on the rhs, or
            an immediate on the lhs, the operands end up swapped -- this is
            not checked here.
            """
            if isinstance(rhs, tvm.tir.IntImm):
                assert lhs.buffer_var.same_as(dst_var)
                return tvm.arith.detect_linear_equation(lhs.index, indices), True, rhs
            if isinstance(lhs, tvm.tir.IntImm):
                assert rhs.buffer_var.same_as(dst_var)
                return tvm.arith.detect_linear_equation(rhs.index, indices), True, lhs
            assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
            src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
            src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
            lhs_equal = _coeffs_match(dst_coeff, src_lhs_coeff)
            rhs_equal = _coeffs_match(dst_coeff, src_rhs_coeff)
            # At least one source must be identical to the destination
            # (in-place computation).
            assert lhs_equal or rhs_equal
            src_coeff = src_rhs_coeff if lhs_equal else src_lhs_coeff
            return src_coeff, False, 0

        def _do_fold(stmt):
            """Rewrite an ``"alu"`` pragma statement into VTA micro-op calls."""
            if not _match_pragma(stmt, "alu"):
                return stmt

            # Descend to the innermost loop body, recording each loop's
            # variable and extent (outermost first).
            indices = []
            extents = []
            loop_body = stmt.body
            while isinstance(loop_body, tvm.tir.For):
                indices.append(loop_body.loop_var)
                extents.append(loop_body.extent)
                loop_body = loop_body.body

            # The innermost store's target is the ALU destination.
            dst_var = loop_body.buffer_var
            dst_idx = loop_body.index
            alu_opcode, lhs, rhs = _derive_alu_op(stmt, loop_body.value)
            dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
            src_coeff, use_imm, imm_val = _derive_src(
                lhs, rhs, dst_var, dst_coeff, indices
            )

            # Ensure the innermost loop(s) match the tensor dimensions
            # (pattern match) before stripping them.
            src_coeff = list(src_coeff)
            dst_coeff = list(dst_coeff)
            assert len(src_coeff) > 1
            assert len(dst_coeff) > 1
            assert len(extents) != 0
            # Coefficients are re-expressed in units of this many elements.
            tile_size = env.BATCH * env.BLOCK_OUT
            assert tvm.ir.structural_equal(
                analyzer.simplify(idxm(src_coeff[-1], tile_size)), 0
            )
            assert tvm.ir.structural_equal(
                analyzer.simplify(idxm(dst_coeff[-1], tile_size)), 0
            )
            assert tvm.ir.structural_equal(src_coeff[-2], 1)
            assert tvm.ir.structural_equal(dst_coeff[-2], 1)
            if env.BATCH > 1:
                assert len(src_coeff) > 2
                assert len(dst_coeff) > 2
                assert len(extents) > 1
                assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
                assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)

            # Apply tensorization: drop the innermost loop(s) covered by the
            # tensor shape, keeping the trailing offsets.
            tensor_loops = 1 if env.BATCH == 1 else 2
            src_coeff = src_coeff[: -(tensor_loops + 1)] + [src_coeff[-1]]
            dst_coeff = dst_coeff[: -(tensor_loops + 1)] + [dst_coeff[-1]]
            extents = extents[:-tensor_loops]
            src_coeff = [analyzer.simplify(c // tile_size) for c in src_coeff]
            dst_coeff = [analyzer.simplify(c // tile_size) for c in dst_coeff]

            # Flatten the outer loops.
            if extents:
                src_coeff, dst_coeff, extents = _flatten_loop(
                    src_coeff, dst_coeff, extents
                )

            # Insert ALU micro-ops: one loop begin/end pair per outer loop
            # around a single uop push using the trailing offsets.
            irb = tvm.tir.ir_builder.create()
            for extent, dst_stride, src_stride in zip(extents, dst_coeff, src_coeff):
                irb.emit(
                    tvm.tir.call_extern(
                        "int32", "VTAUopLoopBegin", extent, dst_stride, src_stride, 0
                    )
                )
            irb.emit(
                tvm.tir.call_intrin(
                    "int32",
                    "tir.vta.uop_push",
                    1,
                    0,
                    dst_coeff[-1],
                    src_coeff[-1],
                    0,
                    alu_opcode,
                    int(use_imm),
                    imm_val,
                )
            )
            for _ in extents:
                irb.emit(tvm.tir.call_extern("int32", "VTAUopLoopEnd"))
            return irb.get()

        return func.with_body(
            tvm.tir.stmt_functor.ir_transform(func.body, None, _do_fold, ["tir.AttrStmt"])
        )

    return tvm.tir.transform.prim_func_pass(
        _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin"
    )