def InjectDMAIntrin()

in vta/python/vta/transform.py

# Module-level imports this pass relies on (from the top of transform.py):
import tvm
from tvm.topi import utils

from .environment import get_env

def InjectDMAIntrin():
    """Pass to inject DMA copy intrinsics.

    Returns
    -------
    fpass : tvm.transform.Pass
        The pass
    """
    # Division/remainder helpers for index expressions (operands assumed non-negative).
    idxd = tvm.tir.indexdiv
    idxm = tvm.tir.indexmod

    def _check_compact(buf):
        """Check that the buffer is compact, i.e. has no gaps between rows."""
        ndim = len(buf.shape)
        size = tvm.tir.const(1, buf.shape[0].dtype)
        # From innermost to outermost dim, the stride must equal the product
        # of all inner extents (row-major layout with no padding).
        for i in reversed(range(ndim)):
            if not utils.equal_const_int(size - buf.strides[i], 0):
                raise RuntimeError(
                    "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides)
                )
            size = size * buf.shape[i]
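
    # Example (illustrative values, not from the source): shape=(4, 8) with
    # strides=(8, 1) is compact and passes; strides=(10, 1) leaves a
    # 2-element gap after each row and raises.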

    def _fold_buffer_dim(buf, scope, elem_block):
        """Fold buffer dims into an (outer..., elem_block) layout."""
        ndim = len(buf.shape)
        x_size = 1
        base = 0
        # Find the innermost run of dims whose extents multiply exactly to
        # elem_block; they must be contiguous (each stride equals the size
        # of everything inside it).
        for i in range(1, ndim + 1):
            if not utils.equal_const_int(buf.strides[ndim - i] - x_size, 0):
                raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
            x_size = x_size * buf.shape[ndim - i]
            if utils.equal_const_int(x_size - elem_block, 0):
                base = i + 1
                break
        if base == 0:
            raise RuntimeError(
                "scope %s needs to have block=%d, shape=%s" % (scope, elem_block, buf.shape)
            )
        shape = [elem_block]
        strides = [1]

        # If the next outer dim is not elem_block-contiguous, insert a unit
        # dim so all remaining folded strides are multiples of elem_block.
        if base < ndim + 1 and not utils.equal_const_int(buf.strides[ndim - base], elem_block):
            shape.append(1)
            strides.append(elem_block)
        analyzer = tvm.arith.Analyzer()
        # Greedily merge runs of mutually contiguous outer dims into single
        # folded dims; every folded stride must be a whole number of blocks.
        while base < ndim + 1:
            x_size = 1
            x_stride = buf.strides[ndim - base]
            next_base = base
            if not utils.equal_const_int(idxm(x_stride, elem_block), 0):
                raise RuntimeError(
                    "scope %s needs to have block=%d, shape=%s, strides=%s"
                    % (scope, elem_block, buf.shape, buf.strides)
                )
            for i in range(base, ndim + 1):
                k = ndim - i
                if not utils.equal_const_int(x_size * x_stride - buf.strides[k], 0):
                    break
                x_size = x_size * buf.shape[k]
                next_base = i + 1
            shape.append(analyzer.simplify(x_size))
            strides.append(x_stride)
            assert next_base != base
            base = next_base

        strides = list(reversed(strides))
        shape = list(reversed(shape))
        return shape, strides
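
    # Worked example (illustrative values, not from the source): with
    # elem_block=16, a buffer of shape=(6, 4, 16) and strides=(64, 16, 1)
    # has a block-sized inner dim and two mutually contiguous outer dims,
    # so the fold returns shape=[24, 16], strides=[16, 1].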

    def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
        """Match the buffer against a 2D (y_size x x_size) block access pattern."""
        # Number of data elements per hardware transfer block.
        elem_block = elem_bytes * 8 // elem_width
        shape, strides = buf.shape, buf.strides
        if not utils.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
            raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
        if allow_fold:
            shape, strides = _fold_buffer_dim(buf, scope, elem_block)
        else:
            shape = list(shape)
            strides = list(strides)

        def raise_error():
            """Raise a diagnostic error carrying the buffer's shape and strides."""
            raise RuntimeError(
                "Scope[%s]: cannot detect 2d pattern with elem_block=%d:"
                " shape=%s, strides=%s" % (scope, elem_block, buf.shape, buf.strides)
            )

        ndim = len(shape)

        # Check if the inner-tensor is already flat
        flat = utils.equal_const_int(shape[-1], elem_block)

        if flat:
            # Flat inner tensor: the trailing dim is exactly one block, so
            # up to two outer dims map directly onto (y, x).
            if not utils.equal_const_int(strides[-1], 1):
                raise_error()

            if ndim == 1:
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(strides[-2] - elem_block, 0):
                raise_error()

            if ndim == 2:
                x_size = shape[-2]
                x_stride = shape[-2]
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(idxm(strides[-3], elem_block), 0):
                raise_error()

            if ndim == 3:
                x_size = shape[-2]
                x_stride = idxd(strides[-3], elem_block)
                y_size = shape[-3]
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

        else:
            # Non-flat inner tensor: the last two dims jointly form one
            # contiguous block (shape[-1] * shape[-2] == elem_block).
            if not utils.equal_const_int(strides[-1], 1):
                raise_error()
            if not utils.equal_const_int(strides[-2] - shape[-1], 0):
                raise_error()
            if not utils.equal_const_int(shape[-1] * shape[-2], elem_block):
                raise_error()

            if ndim == 2:
                x_size = 1
                x_stride = 1
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(strides[-3], elem_block):
                raise_error()

            if ndim == 3:
                x_size = shape[-3]
                x_stride = shape[-3]
                y_size = 1
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
            if not utils.equal_const_int(idxm(strides[-4], elem_block), 0):
                raise_error()

            if ndim == 4:
                x_size = shape[-3]
                x_stride = idxd(strides[-4], elem_block)
                y_size = shape[-4]
                return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)

        raise_error()
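
    # Worked example (illustrative values, not from the source): a flat 3-d
    # buffer with shape=(y, x, elem_block) and strides=(s, elem_block, 1)
    # matches with x_size=x, y_size=y, x_stride=s // elem_block, and block
    # offset elem_offset // elem_block -- exactly the operands that the
    # VTALoadBuffer2D / VTAStoreBuffer2D runtime calls expect.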

    def _inject_copy(src, dst, pad_before, pad_after, pad_value):
        """Lower a single dma_copy pragma into a VTA load/store intrinsic."""
        # FIXME: pad_value is ignored...
        env = get_env()
        _ = pad_value
        if dst.scope == "global":
            # Store path: copy an on-chip scratchpad back out to DRAM.
            if pad_before or pad_after:
                raise RuntimeError("Copy into DRAM with padding is not supported")
            if src.scope == env.acc_scope:
                elem_width = env.OUT_WIDTH
                elem_bytes = env.OUT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_OUT
                data_type = "int%d" % env.OUT_WIDTH
                task_qid = env.dev.QID_STORE_OUT
            else:
                raise RuntimeError("Copy %s->dram is not supported" % (src.scope))
            # The on-chip source must be compact; the DRAM destination is
            # matched against the 2D pattern (folding dims where possible).
            _check_compact(src)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True
            )
            irb = tvm.tir.ir_builder.create()
            # Tag the emitted instructions with the store queue id so the
            # runtime can synchronize them against compute.
            irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))
            irb.emit(
                tvm.tir.call_extern(
                    "int32",
                    "VTAStoreBuffer2D",
                    env.dev.command_handle,
                    src.access_ptr("r", "int32"),  # on-chip source address
                    mem_type,
                    dst.data,  # DRAM destination base
                    offset,  # DRAM offset, in elem_block units
                    x_size,
                    y_size,
                    x_stride,
                )
            )
            return irb.get()
        elif src.scope == "global":
            # Load path: copy DRAM into the matching on-chip scratchpad.
            if dst.scope == env.acc_scope:
                elem_width = env.ACC_WIDTH
                elem_bytes = env.ACC_ELEM_BYTES
                mem_type = env.dev.MEM_ID_ACC
                data_type = "int%d" % env.ACC_WIDTH
                task_qid = env.dev.QID_LOAD_OUT
            elif dst.scope == env.inp_scope:
                elem_width = env.INP_WIDTH
                elem_bytes = env.INP_ELEM_BYTES
                mem_type = env.dev.MEM_ID_INP
                data_type = "int%d" % env.INP_WIDTH
                task_qid = env.dev.QID_LOAD_INP
            elif dst.scope == env.wgt_scope:
                elem_width = env.WGT_WIDTH
                elem_bytes = env.WGT_ELEM_BYTES
                mem_type = env.dev.MEM_ID_WGT
                data_type = "int%d" % env.WGT_WIDTH
                task_qid = env.dev.QID_LOAD_WGT
            else:
                raise RuntimeError("Copy dram->%s is not supported" % (dst.scope))
            # collect pad statistics
            if pad_before:
                assert pad_after
                ndim = len(pad_before)
                if ndim <= 2 or ndim > 5:
                    raise ValueError("Limitation of 2D pad load forbids ndim=%d" % ndim)
                if ndim == 5:
                    # This case occurs when batch size N > 1
                    y_pad_before = pad_before[1]
                    x_pad_before = pad_before[2]
                    y_pad_after = pad_after[1]
                    x_pad_after = pad_after[2]
                    for dim in range(3, ndim):
                        if not utils.equal_const_int(pad_before[dim], 0):
                            raise ValueError("Padding on the innermost block is not supported")
                        if not utils.equal_const_int(pad_after[dim], 0):
                            raise ValueError("Padding on the innermost block is not supported")
                else:
                    y_pad_before = pad_before[0]
                    x_pad_before = pad_before[1]
                    y_pad_after = pad_after[0]
                    x_pad_after = pad_after[1]
                    for dim in range(2, ndim):
                        if not utils.equal_const_int(pad_before[dim], 0):
                            raise ValueError("Padding on the innermost block is not supported")
                        if not utils.equal_const_int(pad_after[dim], 0):
                            raise ValueError("Padding on the innermost block is not supported")
                allow_fold = False
            else:
                x_pad_before = 0
                y_pad_before = 0
                x_pad_after = 0
                y_pad_after = 0
                allow_fold = True
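
            # Note: y/x above are the two spatial axes of the tiled layout,
            # dims 1/2 when a leading batch dim is present (ndim == 5) and
            # dims 0/1 otherwise; inner block dims must never be padded.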

            # The on-chip destination must be compact; the DRAM source is
            # matched against the 2D pattern (folding is disabled when the
            # load is padded).
            _check_compact(dst)
            x_size, y_size, x_stride, offset = _get_2d_pattern(
                src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold
            )

            # Special case: loading narrow int8 data directly into the wide
            # accumulator scratchpad uses the dedicated ACC_8BIT memory id.
            if data_type != src.dtype:
                assert data_type == "int%d" % env.ACC_WIDTH and src.dtype == "int%d" % env.INP_WIDTH
                mem_type = env.dev.MEM_ID_ACC_8BIT

            irb = tvm.tir.ir_builder.create()
            irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid))

            irb.emit(
                tvm.tir.call_extern(
                    "int32",
                    "VTALoadBuffer2D",
                    env.dev.command_handle,
                    src.data,  # DRAM source base
                    offset,  # DRAM offset, in elem_block units
                    x_size,
                    y_size,
                    x_stride,
                    x_pad_before,
                    y_pad_before,
                    x_pad_after,
                    y_pad_after,
                    dst.access_ptr("r", "int32"),  # on-chip destination address
                    mem_type,
                )
            )
            return irb.get()

        else:
            raise RuntimeError("Copy %s->%s is not supported" % (src.scope, dst.scope))

    return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
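
Usage sketch (assumed wiring, not part of this file): vta.build_config
registers this pass through TVM's "tir.add_lower_pass" hook; the phase
number below follows that convention and the lowering call is a placeholder.

import tvm
import vta.transform

# Register the pass as a custom lowering pass: any schedule stage tagged
# with the "dma_copy" pragma is rewritten into VTALoadBuffer2D /
# VTAStoreBuffer2D calls when modules are lowered under this context.
dma_pass = vta.transform.InjectDMAIntrin()
with tvm.transform.PassContext(config={"tir.add_lower_pass": [(1, dma_pass)]}):
    pass  # lower or build VTA schedules here, e.g. via tvm.lower(...)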