# vta/python/vta/transform.py
import tvm
from tvm.topi import utils

from .environment import get_env


def InjectDMAIntrin():
    """Pass to inject DMA copy intrinsics.

    Returns
    -------
    fpass : tvm.transform.Pass
        The pass
    """
    idxd = tvm.tir.indexdiv
    idxm = tvm.tir.indexmod

    def _check_compact(buf):
        """Check that the buffer is compact, i.e. row-major contiguous."""
        ndim = len(buf.shape)
        size = tvm.tir.const(1, buf.shape[0].dtype)
        for i in reversed(range(ndim)):
            if not utils.equal_const_int(size - buf.strides[i], 0):
                raise RuntimeError(
                    "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)
                )
            size = size * buf.shape[i]
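
    # Hedged worked example (hypothetical numbers, not from this file): a buffer
    # with shape=(4, 16) and strides=(16, 1) passes the check above, since the
    # running row-major size (1, then 16) matches each stride.  With
    # strides=(32, 1) the outer comparison is 16 vs. 32 and _check_compact raises.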

    def _fold_buffer_dim(buf, scope, elem_block):
        """Fold contiguous buffer dimensions so the innermost one equals elem_block."""
        ndim = len(buf.shape)
        x_size = 1
        base = 0
        for i in range(1, ndim + 1):
            if not utils.equal_const_int(buf.strides[ndim - i] - x_size, 0):
                raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
            x_size = x_size * buf.shape[ndim - i]
            if utils.equal_const_int(x_size - elem_block, 0):
                base = i + 1
                break
        if base == 0:
            raise RuntimeError(
                "scope %s needs to have block=%d, shape=%s" % (scope, elem_block, buf.shape)
            )
        shape = [elem_block]
        strides = [1]

        if base < ndim + 1 and not utils.equal_const_int(buf.strides[ndim - base], elem_block):
            shape.append(1)
            strides.append(elem_block)

        analyzer = tvm.arith.Analyzer()
        while base < ndim + 1:
            x_size = 1
            x_stride = buf.strides[ndim - base]
            next_base = base
            if not utils.equal_const_int(idxm(x_stride, elem_block), 0):
                raise RuntimeError(
                    "scope %s needs to have block=%d, shape=%s, strides=%s"
                    % (scope, elem_block, buf.shape, buf.strides)
                )
            for i in range(base, ndim + 1):
                k = ndim - i
                if not utils.equal_const_int(x_size * x_stride - buf.strides[k], 0):
                    break
                x_size = x_size * buf.shape[k]
                next_base = i + 1
            shape.append(analyzer.simplify(x_size))
            strides.append(x_stride)
            assert next_base != base
            base = next_base

        strides = list(reversed(strides))
        shape = list(reversed(shape))
        return shape, strides
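
    # Hedged worked example (hypothetical numbers): shape=(4, 2, 16) with
    # strides=(32, 16, 1) and elem_block=16 folds to shape=[8, 16],
    # strides=[16, 1] -- the two outer dimensions are contiguous in
    # elem_block units, so they collapse into a single dimension of size 8.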

    def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
        """Deduce a 2D access pattern (x_size, y_size, x_stride, offset) in elem_block units."""
        elem_block = elem_bytes * 8 // elem_width
        shape, strides = buf.shape, buf.strides
        if not utils.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
            raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
        if allow_fold:
            shape, strides = _fold_buffer_dim(buf, scope, elem_block)
        else:
            shape = list(x for x in shape)
            strides = list(x for x in strides)

        def raise_error():
            """Internal function to raise error"""
            raise RuntimeError(
                (
                    "Scope[%s]: cannot detect 2d pattern with elem_block=%d:"
                    + " shape=%s, strides=%s"
                )
                % (scope, elem_block, buf.shape, buf.strides)
            )

        ndim = len(shape)

        # Check if the inner-tensor is already flat
        flat = utils.equal_const_int(shape[-1], elem_block)
        if flat:
            if not utils.equal_const_int(strides[-1], 1):
                raise_error()

            if ndim == 1:
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(strides[-2] - elem_block, 0):
                raise_error()

            if ndim == 2:
                x_size = shape[-2]
                x_stride = shape[-2]
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(idxm(strides[-3], elem_block), 0):
                raise_error()

            if ndim == 3:
                x_size = shape[-2]
                x_stride = idxd(strides[-3], elem_block)
                y_size = shape[-3]
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

        else:
            # The innermost block spans the last two dimensions.
            if not utils.equal_const_int(strides[-1], 1):
                raise_error()
            if not utils.equal_const_int(strides[-2] - shape[-1], 0):
                raise_error()
            if not utils.equal_const_int(shape[-1] * shape[-2], elem_block):
                raise_error()

            if ndim == 2:
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(strides[-3], elem_block):
                raise_error()

            if ndim == 3:
                x_size = shape[-3]
                x_stride = shape[-3]
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(idxm(strides[-4], elem_block), 0):
                raise_error()

            if ndim == 4:
                x_size = shape[-3]
                x_stride = idxd(strides[-4], elem_block)
                y_size = shape[-4]
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

        raise_error()
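
    # Hedged worked example (hypothetical numbers): a flat 3-D buffer with
    # shape=(4, 2, 16), strides=(64, 16, 1), elem_offset=0 and elem_block=16
    # yields (x_size, y_size, x_stride, offset) = (2, 4, 4, 0): four rows of
    # two vectors each, with a row pitch of four vectors.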

    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
        """Rewrite one dma_copy-annotated copy into a VTA load/store intrinsic."""
        # FIXME: pad_value is ignored...
        env = get_env()
        _ = pad_value
        if dst.scope == "global":
            # Store
            if pad_before or pad_after:
                raise RuntimeError("Do not support copy into DRAM with pad")
            if src.scope == env.acc_scope:
                elem_width = env.OUT_WIDTH
                elem_bytes = env.OUT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_OUT
                data_type = "int%d" % env.OUT_WIDTH
                task_qid = env.dev.QID_STORE_OUT
            else:
                raise RuntimeError("Do not support copy %s->dram" % (src.scope))
            _check_compact(src)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True
            )
            irb = tvm.tir.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
            irb.emit(
                tvm.tir.call_extern(
                    "int32",
                    "VTAStoreBuffer2D",
                    env.dev.command_handle,
                    src.access_ptr("r", "int32"),
                    mem_type,
                    dst.data,
                    offset,
                    x_size,
                    y_size,
                    x_stride,
                )
            )
            return irb.get()
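            # Informal reading of the call above (inferred from _get_2d_pattern,
            # not an authoritative description of the runtime API): offset,
            # x_size, y_size and x_stride are all in units of elem_block
            # vectors, so the store moves y_size rows of x_size vectors, with
            # consecutive rows x_stride vectors apart in DRAM.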
        elif src.scope == "global":
            # Load
            if dst.scope == env.acc_scope:
                elem_width = env.ACC_WIDTH
                elem_bytes = env.ACC_ELEM_BYTES
                mem_type = env.dev.MEM_ID_ACC
                data_type = "int%d" % env.ACC_WIDTH
                task_qid = env.dev.QID_LOAD_OUT
            elif dst.scope == env.inp_scope:
                elem_width = env.INP_WIDTH
                elem_bytes = env.INP_ELEM_BYTES
                mem_type = env.dev.MEM_ID_INP
                data_type = "int%d" % env.INP_WIDTH
                task_qid = env.dev.QID_LOAD_INP
            elif dst.scope == env.wgt_scope:
                elem_width = env.WGT_WIDTH
                elem_bytes = env.WGT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_WGT
                data_type = "int%d" % env.WGT_WIDTH
                task_qid = env.dev.QID_LOAD_WGT
            else:
                raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
            # collect pad statistics
            if pad_before:
                assert pad_after
                ndim = len(pad_before)
                if ndim <= 2 or ndim > 5:
                    raise ValueError("Limitation of 2D pad load forbids ndim=%d" % ndim)
                if ndim == 5:
                    # This case occurs when batch size N > 1
                    y_pad_before = pad_before[1]
                    x_pad_before = pad_before[2]
                    y_pad_after = pad_after[1]
                    x_pad_after = pad_after[2]
                    for dim in range(3, ndim):
                        if not utils.equal_const_int(pad_before[dim], 0):
                            raise ValueError("Do not support pad on the innermost block")
                        if not utils.equal_const_int(pad_after[dim], 0):
                            raise ValueError("Do not support pad on the innermost block")
                else:
                    y_pad_before = pad_before[0]
                    x_pad_before = pad_before[1]
                    y_pad_after = pad_after[0]
                    x_pad_after = pad_after[1]
                    for dim in range(2, ndim):
                        if not utils.equal_const_int(pad_before[dim], 0):
                            raise ValueError("Do not support pad on the innermost block")
                        if not utils.equal_const_int(pad_after[dim], 0):
                            raise ValueError("Do not support pad on the innermost block")
                # Dimension folding is disabled here so the pads keep referring
                # to the original dimensions.
                allow_fold = False
            else:
                x_pad_before = 0
                y_pad_before = 0
                x_pad_after = 0
                y_pad_after = 0
                allow_fold = True
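            # Hedged worked example (hypothetical pads): for a 4-D buffer,
            # pad_before=[1, 1, 0, 0] / pad_after=[1, 1, 0, 0] map to
            # y_pad_before = x_pad_before = y_pad_after = x_pad_after = 1,
            # while any nonzero pad on the two innermost (block) dimensions
            # would be rejected above.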
            _check_compact(dst)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold
            )
            if data_type != src.dtype:
                # Narrow load: an int8 source is widened into the accumulator scope.
                assert data_type == "int%d" % env.ACC_WIDTH and src.dtype == "int%d" % env.INP_WIDTH
                mem_type = env.dev.MEM_ID_ACC_8BIT
            irb = tvm.tir.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
            irb.emit(
                tvm.tir.call_extern(
                    "int32",
                    "VTALoadBuffer2D",
                    env.dev.command_handle,
                    src.data,
                    offset,
                    x_size,
                    y_size,
                    x_stride,
                    x_pad_before,
                    y_pad_before,
                    x_pad_after,
                    y_pad_after,
                    dst.access_ptr("r", "int32"),
                    mem_type,
                )
            )
            return irb.get()
        else:
            raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))

    return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
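

# Usage sketch (hedged: `s`, `a`, `b`, `a_buf` are a hypothetical schedule and
# tensors; in upstream VTA this pass is wired into lowering by vta.build_config
# rather than invoked directly):
#
#     import vta
#     env = vta.get_env()
#     s[a_buf].pragma(s[a_buf].op.axis[0], env.dma_copy)  # mark the copy stage
#     with vta.build_config():
#         mod = vta.build(s, [a, b], "ext_dev", env.target_host)
#
# Every copy loop tagged with the `dma_copy` pragma is then rewritten by
# _inject_copy into a VTALoadBuffer2D or VTAStoreBuffer2D call.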