def schedule_hwnc_tensorcore_cuda(cfg, s, Conv)

in python/tvm/topi/cuda/conv2d_hwnc_tensorcore.py


def schedule_hwnc_tensorcore_cuda(cfg, s, Conv):
    """Schedule tensorcore template"""
    pad_data, packed_kernel = s[Conv].op.input_tensors
    ic, kh, kw, ii = s[Conv].op.reduce_axis
    packed_data = s[pad_data].op.input_tensors[0]
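    # Producer chain: packed_data -> pad_data -> Conv, with packed_kernel as the
    # second input. Reduce axes: outer input channel (ic), kernel height/width
    # (kh/kw), and the inner input channel (ii) that feeds the wmma reduction.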

    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    block_z = te.thread_axis("blockIdx.z")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")
    thread_z = te.thread_axis("threadIdx.z")

    # Designate the memory hierarchy
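    # global -> shared (AS/WS) -> wmma matrix_a/matrix_b fragments (AF/WF);
    # results accumulate in wmma.accumulator (ConvF) before the write-back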
    AS = s.cache_read(pad_data, "shared", [Conv])
    WS = s.cache_read(packed_kernel, "shared", [Conv])
    AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
    WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
    ConvF = s.cache_write(Conv, "wmma.accumulator")

    if Conv.op in s.outputs:
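        # Conv is the final op: stage the accumulator through shared memory
        # (ConvS) before the global store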
        output = Conv
        ConvS = s.cache_read(ConvF, "shared", [Conv])
        OL = ConvS
    else:
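        # another op is fused after Conv: demote Conv itself to shared scope;
        # the fused op produces the real output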
        output = s.outputs[0].output(0)
        s[Conv].set_scope("shared")
        OL = Conv

    out_dtype = Conv.dtype

    if isinstance(packed_kernel.op, te.tensor.ComputeOp) and packed_kernel.name == "packed_kernel":
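        # Schedule the kernel-packing op in place, or skip it during tuning
        # since it can be folded ahead of time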
        if autotvm.GLOBAL_SCOPE.in_tuning:
            s[packed_kernel].pragma(s[packed_kernel].op.axis[0], "debug_skip_region")
        else:
            with Target("cuda"):
                schedule_injective_from_existing(s, packed_kernel)

    if isinstance(pad_data.op, te.tensor.ComputeOp) and "pad" in pad_data.op.tag:
        s[pad_data].compute_inline()
        data = pad_data.op.input_tensors[0]

        if autotvm.GLOBAL_SCOPE.in_tuning:
            # skip this part during tuning to make records accurate
            # this part will be pre-computed during NNVM's pre-compute optimization pass
            s[pad_data].pragma(s[pad_data].op.axis[0], "debug_skip_region")
    else:
        data = pad_data
        s[data].compute_inline()

    data_dtype = data.dtype
    kernel_dtype = packed_kernel.dtype

    # Schedule for autotvm
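    # Tuning space: warp/block tiling (block_*_warps, warp_*_tiles), reduction
    # chunking (chunk), a grid-level split of the fused H*W axis
    # (split_block_k_nums), and vector widths for the shared-memory copies
    # (vector_as / vector_ws)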
    cfg.define_knob("block_row_warps", [1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
    cfg.define_knob("warp_row_tiles", [1, 2, 4, 8, 16])
    cfg.define_knob("warp_col_tiles", [1, 2, 4, 8, 16])
    cfg.define_knob("chunk", [1, 2, 4, 8])
    cfg.define_knob("split_block_k_nums", [1, 2, 4, 8, 16, 32])
    cfg.define_knob("vector_ws", [1, 8])
    cfg.define_knob("vector_as", [1, 8, 16])

    block_row_warps = cfg["block_row_warps"].val
    block_col_warps = cfg["block_col_warps"].val
    warp_row_tiles = cfg["warp_row_tiles"].val
    warp_col_tiles = cfg["warp_col_tiles"].val
    chunk = cfg["chunk"].val
    vector_as = cfg["vector_as"].val
    vector_ws = cfg["vector_ws"].val
    split_block_k_nums = cfg["split_block_k_nums"].val

    s[packed_data].compute_inline()

    if data_dtype in ["int4", "uint4"]:
        wmma_m = wmma_n = 8
        wmma_k = 32
    else:
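        # remaining dtypes (e.g. int8) use the 8x32x16 wmma fragment shape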
        wmma_m = 8
        wmma_n = 32
        wmma_k = 16

    warp_size = 32

    # Schedule for output
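    # Grid layout: the fused H*W axis goes to blockIdx.z (further split by
    # split_block_k_nums), batch tiles to blockIdx.x, output-channel tiles to
    # blockIdx.y; warps map to threadIdx.y/threadIdx.z, lanes to threadIdx.x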
    if len(s[output].op.axis) == 4:
        hc, wc, nc, oc = output.op.axis
        nc, nnc = s[output].split(nc, factor=wmma_m)
        oc, ooc = s[output].split(oc, factor=wmma_n)
    else:
        hc, wc, nc, oc, nnc, ooc = output.op.axis

    kernel_scope, hc = s[output].split(hc, nparts=1)

    block_k = s[output].fuse(hc, wc)
    block_k, split_block_k = s[output].split(block_k, factor=split_block_k_nums)
    nc, nci = s[output].split(nc, factor=warp_row_tiles)
    block_i, nc = s[output].split(nc, factor=block_row_warps)
    oc, oci = s[output].split(oc, factor=warp_col_tiles)
    block_j, oc = s[output].split(oc, factor=block_col_warps)
    s[output].reorder(block_k, split_block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
    t = s[output].fuse(nnc, ooc)
    _, tx = s[output].split(t, factor=warp_size)
    s[output].bind(block_k, block_z)
    s[output].bind(block_i, block_x)
    s[output].bind(block_j, block_y)
    s[output].bind(tx, thread_x)
    s[output].bind(nc, thread_y)
    s[output].bind(oc, thread_z)

    # Schedule wmma store
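    # The store stage replays the output's warp tiling so each warp writes back
    # exactly the accumulator fragments it owns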
    s[OL].compute_at(s[output], block_j)
    hc, wc, nc, oc, nnc, ooc = OL.op.axis
    oc, oci = s[OL].split(oc, factor=warp_col_tiles)
    _, oc = s[OL].split(oc, factor=block_col_warps)
    nc, nci = s[OL].split(nc, factor=warp_row_tiles)
    _, nc = s[OL].split(nc, factor=block_row_warps)
    s[OL].reorder(nc, oc, nci, oci, nnc, ooc)
    s[OL].bind(nc, thread_y)
    s[OL].bind(oc, thread_z)

    # Schedule local computation
    s[ConvF].compute_at(s[OL], oc)
    _, _, n, o, nnf, oof = ConvF.op.axis
    ko, ki = s[ConvF].split(ic, factor=chunk)
    s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)

    cfg.define_reorder("reorder_inner", [ko, kh], policy="all")
    cfg["reorder_inner"].apply(s, ConvF, [ko, kh])
    cfg["reorder_inner"].apply(s, ConvF, [ki, kw])

    # Fetch the wmma input fragments (AF/WF) at the kw loop inside the reduction nest
    s[AF].compute_at(s[ConvF], kw)
    s[WF].compute_at(s[ConvF], kw)

    # Schedule for A's shared memory
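    # Cooperative fetch: the copy is spread across threadIdx.y/z/x and the
    # innermost elements are vectorized by vector_as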
    s[AS].compute_at(s[ConvF], ko)

    _, _, n, _, nn, ii = AS.op.axis
    tx, xo = s[AS].split(n, nparts=block_row_warps)
    ty, _ = s[AS].split(xo, nparts=block_col_warps)
    t = s[AS].fuse(nn, ii)
    to, ti = s[AS].split(t, nparts=warp_size)
    ti, _t = s[AS].split(ti, factor=vector_as)
    s[AS].bind(tx, thread_y)
    s[AS].bind(ty, thread_z)
    s[AS].bind(to, thread_x)
    s[AS].vectorize(_t)

    # Schedule for W's shared memory
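    # Same cooperative-fetch pattern for the packed kernel, with its own
    # vector-width knob (vector_ws)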
    s[WS].compute_at(s[ConvF], kw)
    kh, kw, ic, o, ii, oo = WS.op.axis
    tx, xo = s[WS].split(o, nparts=block_row_warps)
    ty, _ = s[WS].split(xo, nparts=block_col_warps)
    t = s[WS].fuse(ii, oo)
    to, ti = s[WS].split(t, nparts=warp_size)
    ti, _t = s[WS].split(ti, factor=vector_ws)
    s[WS].bind(tx, thread_y)
    s[WS].bind(ty, thread_z)
    s[WS].bind(to, thread_x)
    s[WS].vectorize(ti)

    # double buffer
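    # (overlaps the next reduction step's shared-memory loads with compute)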
    cfg.define_knob("AS_double_buffer", [0, 1])
    cfg.define_knob("WS_double_buffer", [0, 1])
    if cfg["AS_double_buffer"].val:
        s[AS].double_buffer()
    if cfg["WS_double_buffer"].val:
        s[WS].double_buffer()

    # unroll
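    # kernel_scope is the extent-1 outer axis split off earlier; it exists only
    # to anchor these pragmas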
    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
    s[output].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
    s[output].pragma(kernel_scope, "unroll_explicit", False)

    shape = (wmma_m, wmma_n, wmma_k)

    AS_shape = (wmma_m, wmma_k)
    AL_shape = (wmma_m, wmma_k)
    WS_shape = (wmma_n, wmma_k)
    WL_shape = (wmma_n, wmma_k)
    CL_shape = (wmma_m, wmma_n)
    CS_shape = (wmma_m, wmma_n)

    AL_gemm = te.placeholder(AL_shape, name="A", dtype=data_dtype)
    WL_gemm = te.placeholder(WL_shape, name="B", dtype=kernel_dtype)
    k_gemm = te.reduce_axis((0, wmma_k), name="k")
    CL_compute = te.compute(
        CL_shape,
        lambda ii, jj: te.sum(
            (AL_gemm[ii, k_gemm].astype("int32") * WL_gemm[jj, k_gemm].astype("int32")), axis=k_gemm
        ),
        name="C",
    )
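    # This per-fragment mini-GEMM, C[m, n] += A[m, k] * B[n, k] with int32
    # accumulation, is the pattern the wmma intrinsics below are built from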

    AL_strides = [wmma_k, 1]
    AS_strides = [wmma_k, 1]
    WL_strides = [wmma_k, 1]
    WS_strides = [wmma_k, 1]
    CL_strides = [wmma_n, 1]
    CS_strides = [wmma_n, 1]
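    # (all tiles are dense: leading stride wmma_k for the A/W fragments and
    # wmma_n for the accumulator)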

    s[AF].tensorize(
        AF.op.axis[-2],
        intrin_wmma_load_matrix_A(
            AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, data_dtype
        ),
    )

    s[WF].tensorize(
        WF.op.axis[-2],
        intrin_wmma_load_matrix_W(
            WL_strides, WS_strides, shape, "col_major", WS_shape, WL_shape, kernel_dtype
        ),
    )

    s[OL].tensorize(
        nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)
    )

    s[ConvF].tensorize(
        nnf,
        intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape),
    )

    return s
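
Below is a minimal, hedged usage sketch, not taken from this file: it assumes the
registered TOPI entry points topi.cuda.conv2d_hwnc_tensorcore and
topi.cuda.schedule_conv2d_hwnc_tensorcore from this same module, and the shapes,
dtypes, and conv parameters are illustrative only. autotvm template wrappers do
not accept keyword arguments, so the conv parameters are passed positionally.

import tvm
from tvm import te, topi

# HWNC activations on the int8 Tensor Core path (shapes are assumptions)
data = te.placeholder((56, 56, 16, 64), name="data", dtype="int8")
kernel = te.placeholder((3, 3, 64, 64), name="kernel", dtype="int8")

with tvm.target.Target("cuda"):
    # strides=1, padding=1, dilation=1, in_dtype="int8" (positional)
    conv = topi.cuda.conv2d_hwnc_tensorcore(data, kernel, 1, 1, 1, "int8")
    # builds the schedule through this function under a fallback autotvm config
    s = topi.cuda.schedule_conv2d_hwnc_tensorcore([conv])
    func = tvm.build(s, [data, kernel, conv], target="cuda")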