def schedule_ndhwc_tensorcore_cuda()

in python/tvm/topi/cuda/conv3d_ndhwc_tensorcore.py [0:0]


def schedule_ndhwc_tensorcore_cuda(cfg, s, Conv):
    """Schedule tensorcore template"""
    kd, kh, kw, ic = s[Conv].op.reduce_axis
    out_dtype = Conv.dtype
    trans_paddata, kernel = s[Conv].op.input_tensors
    in_dtype = trans_paddata.dtype
    batch, _, _, _, _ = get_const_tuple(Conv.shape)
    _, _, _, _, out_channels = get_const_tuple(kernel.shape)
    paddata = s[trans_paddata].op.input_tensors

    # inline the pad and dtype transform
    s[trans_paddata].compute_inline()
    s[kernel].compute_inline()
    s[paddata[0]].compute_inline()

    # Designate the memory hierarchy
    AS = s.cache_read(trans_paddata, "shared", [Conv])
    WS = s.cache_read(kernel, "shared", [Conv])
    AF = s.cache_read(AS, "wmma.matrix_a", [Conv])
    WF = s.cache_read(WS, "wmma.matrix_b", [Conv])
    ConvF = s.cache_write(Conv, "wmma.accumulator")

    if Conv.op in s.outputs:
        output = Conv
        ConvS = s.cache_read(ConvF, "shared", [Conv])
        OL = ConvS
    else:
        output = s.outputs[0].output(0)
        s[Conv].set_scope("shared")
        OL = Conv

    # Schedule for autotvm
    cfg.define_knob("block_row_warps", [1, 2, 4])
    cfg.define_knob("block_col_warps", [1, 2, 4])
    cfg.define_knob("warp_row_tiles", [1, 2, 4])
    cfg.define_knob("warp_col_tiles", [1, 2, 4])
    cfg.define_knob("chunk", [1, 2, 4, 8])
    cfg.define_knob("offset", [0, 8])
    cfg.define_knob("vector_width", [1, 2, 4, 8])

    if batch % 16 == 0 and out_channels % 16 == 0:
        cfg.define_knob("wmma_m", [16, 8, 32])
    elif batch % 8 == 0 and out_channels % 32 == 0:
        cfg.define_knob("wmma_m", [8, 16, 32])
    elif batch % 32 == 0 and out_channels % 8 == 0:
        cfg.define_knob("wmma_m", [32, 16, 8])

    # fallback support
    target = tvm.target.Target.current()
    if cfg.is_fallback:
        ref_log = autotvm.tophub.load_reference_log(
            target.kind.name, target.model, "conv3d_ndhwc_tensorcore.cuda"
        )
        cfg.fallback_with_reference_log(ref_log)

    block_row_warps = cfg["block_row_warps"].val
    block_col_warps = cfg["block_col_warps"].val
    warp_row_tiles = cfg["warp_row_tiles"].val
    warp_col_tiles = cfg["warp_col_tiles"].val
    chunk = cfg["chunk"].val
    offset = cfg["offset"].val
    wmma_m = cfg["wmma_m"].val
    vector_width = cfg["vector_width"].val

    wmma_k = 16
    if wmma_m == 16:
        wmma_n = 16
    elif wmma_m == 8:
        wmma_n = 32
    elif wmma_m == 32:
        wmma_n = 8

    warp_size = 32

    block_x = te.thread_axis("blockIdx.x")
    block_y = te.thread_axis("blockIdx.y")
    block_z = te.thread_axis("blockIdx.z")
    thread_x = te.thread_axis("threadIdx.x")
    thread_y = te.thread_axis("threadIdx.y")
    thread_z = te.thread_axis("threadIdx.z")

    # Define the intrin strides
    def get_strides(extents):
        return [np.prod(extents[i:]).tolist() for i in range(len(extents))]

    AS_align = chunk * wmma_k + offset
    WS_align = warp_col_tiles * block_col_warps * wmma_n + offset
    block_factor_n = wmma_m * warp_row_tiles * block_row_warps
    block_factor_o = wmma_n * warp_col_tiles * block_col_warps
    CS_align = block_factor_o + offset
    AS_strides = get_strides([1, 1, 1, AS_align, 1])
    AL_strides = get_strides([1, 1, 1, wmma_k, 1])
    WS_strides = get_strides([WS_align, 1])
    WL_strides = get_strides([wmma_n * warp_col_tiles, 1])
    CL_strides = get_strides([1, 1, 1, wmma_n * warp_col_tiles, 1])
    CS_strides = get_strides([1, 1, 1, CS_align, 1])

    # Schedule for output
    nc, dc, hc, wc, oc = output.op.axis
    block_k = s[output].fuse(dc, hc, wc)
    s[output].bind(block_k, block_z)
    block_i, nc = s[output].split(nc, factor=block_factor_n)
    block_j, oc = s[output].split(oc, factor=block_factor_o)
    s[output].reorder(block_k, block_i, block_j, nc, oc)
    t = s[output].fuse(nc, oc)
    t, ti = s[output].split(t, factor=vector_width)
    t, tx = s[output].split(t, factor=warp_size)
    t, ty = s[output].split(t, factor=block_row_warps)
    t, tz = s[output].split(t, factor=block_col_warps)
    s[output].bind(block_i, block_x)
    s[output].bind(block_j, block_y)
    s[output].bind(tz, thread_z)
    s[output].bind(ty, thread_y)
    s[output].bind(tx, thread_x)
    s[output].vectorize(ti)

    # Schedule wmma store
    s[OL].compute_at(s[output], block_j)
    nc, dc, hc, wc, oc = OL.op.axis
    s[OL].reorder(dc, hc, wc, nc, oc)
    s[OL].storage_align(wc, CS_align - 1, CS_align)
    oc, ooc = s[OL].split(oc, factor=wmma_n)
    oc, oci = s[OL].split(oc, factor=warp_col_tiles)
    _, oc = s[OL].split(oc, factor=block_col_warps)
    nc, nnc = s[OL].split(nc, factor=wmma_m)
    nc, nci = s[OL].split(nc, factor=warp_row_tiles)
    _, nc = s[OL].split(nc, factor=block_row_warps)
    s[OL].reorder(nc, oc, nci, oci, nnc, ooc)
    s[OL].bind(nc, thread_y)
    s[OL].bind(oc, thread_z)

    # Schedule wmma computation
    s[ConvF].compute_at(s[OL], oc)
    n, d, h, w, o = ConvF.op.axis
    n, nnf = s[ConvF].split(n, factor=wmma_m)
    o, oof = s[ConvF].split(o, factor=wmma_n)
    ic, ii = s[ConvF].split(ic, factor=wmma_k)
    ko, ki = s[ConvF].split(ic, factor=chunk)
    s[ConvF].reorder(kd, kh, kw, ko, ki, n, o, nnf, oof, ii)

    s[AF].compute_at(s[ConvF], ki)
    s[WF].compute_at(s[ConvF], ki)

    # Schedule wmma load
    n, d, h, w, i = AF.op.axis
    n, nn = s[AF].split(n, factor=wmma_m)
    i, ii = s[AF].split(i, factor=wmma_k)
    s[AF].reorder(n, i, nn, ii)

    kd, kh, kw, i, o = WF.op.axis
    i, ii = s[WF].split(i, factor=wmma_k)
    o, oo = s[WF].split(o, factor=wmma_n)
    s[WF].reorder(o, i, oo)
    s[WF].reorder(i, o, ii, oo)

    s[WS].compute_at(s[ConvF], ko)
    s[AS].compute_at(s[ConvF], ko)

    # Schedule for data's share memory
    n, d, h, w, i = AS.op.axis
    s[AS].reorder(d, h, w, n, i)
    s[AS].storage_align(w, AS_align - 1, AS_align)
    t = s[AS].fuse(n, i)
    t, ti = s[AS].split(t, factor=vector_width)
    t, tx = s[AS].split(t, factor=warp_size)
    t, ty = s[AS].split(t, factor=block_row_warps)
    _, tz = s[AS].split(t, factor=block_col_warps)
    s[AS].bind(ty, thread_y)
    s[AS].bind(tz, thread_z)
    s[AS].bind(tx, thread_x)
    s[AS].vectorize(ti)

    # Schedule for kernel's share memory
    kd, kh, kw, ic, o = WS.op.axis
    t = s[WS].fuse(ic, o)
    s[WS].storage_align(ic, WS_align - 1, WS_align)
    t, ti = s[WS].split(t, factor=vector_width)
    t, tx = s[WS].split(t, factor=warp_size)
    t, ty = s[WS].split(t, factor=block_row_warps)
    _, tz = s[WS].split(t, factor=block_col_warps)
    s[WS].bind(ty, thread_y)
    s[WS].bind(tz, thread_z)
    s[WS].bind(tx, thread_x)
    s[WS].vectorize(ti)

    shape = (wmma_m, wmma_n, wmma_k)

    # tensorize the wmma process
    AS_shape = (wmma_m, 1, 1, 1, wmma_k)
    AL_shape = (wmma_m, 1, 1, 1, wmma_k)
    WS_shape = (wmma_k, wmma_n)
    WL_shape = (wmma_k, wmma_n)
    CL_shape = (wmma_m, 1, 1, 1, wmma_n)
    CS_shape = (wmma_m, 1, 1, 1, wmma_n)

    AL_gemm = te.placeholder(AL_shape, name="A", dtype=in_dtype)
    WL_gemm = te.placeholder(WL_shape, name="B", dtype=in_dtype)
    k_gemm = te.reduce_axis((0, wmma_k), name="k")
    CL_compute = te.compute(
        CL_shape,
        lambda ii, t0, t1, t2, jj: te.sum(
            AL_gemm[ii, t0, t1, t2, k_gemm].astype(out_dtype)
            * WL_gemm[k_gemm, jj].astype(out_dtype),
            axis=k_gemm,
        ),
        name="C",
    )

    s[AF].tensorize(
        nn,
        intrin_wmma_load_matrix_A(
            AL_strides, AS_strides, shape, "row_major", AS_shape, AL_shape, in_dtype
        ),
    )
    s[WF].tensorize(
        ii,
        intrin_wmma_load_matrix_W(
            WL_strides, WS_strides, shape, "row_major", WS_shape, WL_shape, in_dtype
        ),
    )
    s[OL].tensorize(
        nnc, intrin_wmma_store_matrix(CS_strides, CL_strides, shape, out_dtype, CL_shape, CS_shape)
    )
    s[ConvF].tensorize(
        nnf,
        intrin_wmma_gemm(AL_gemm, WL_gemm, CL_compute, AL_strides, WL_strides, CL_strides, shape),
    )

    N, OD, OH, OW, CO = get_const_tuple(output.shape)
    KD, KH, KW, CI, _ = get_const_tuple(kernel.shape)
    cfg.add_flop(2 * N * OD * OH * OW * CO * CI * KD * KH * KW)