def schedule_vanilla_spmm_csr_cuda()

in python/featgraph/op/vanilla_spmm.py [0:0]


def schedule_vanilla_spmm_csr_cuda(Out,
                                   num_cuda_blocks=None,
                                   num_threads_per_cuda_block=None):
    """Create a CUDA schedule for the vanilla CSR SpMM output tensor.

    The row axis (axis 0) of ``Out`` is split into ``num_cuda_blocks`` parts
    bound to ``blockIdx.x``. The feature axis (axis 1) is split by
    ``num_threads_per_cuda_block``: the outer part is bound to ``blockIdx.y``
    and the inner part to ``threadIdx.x``. Each thread then iterates serially
    over the rows assigned to its block.

    Parameters
    ----------
    Out : te.Tensor
        Output of the SpMM compute. Only its first two axes are scheduled,
        as (row, feature). Both extents must be static integers, because
        ``.value`` is read from them.
    num_cuda_blocks : int, optional
        Number of CUDA blocks along the row dimension (``blockIdx.x``).
        Defaults to the number of rows, i.e. one row per block.
    num_threads_per_cuda_block : int, optional
        Number of threads per block along the feature dimension
        (``threadIdx.x``). Defaults to the feature length, capped at the
        CUDA per-block limit of 1024 threads. Any remaining features are
        handled by additional blocks along ``blockIdx.y``.

    Returns
    -------
    te.Schedule
        The schedule for ``Out``.
    """
    # CUDA refuses to launch a kernel with more than 1024 threads per block,
    # so the default thread count must not exceed this.
    max_threads_per_cuda_block = 1024
    s = te.create_schedule([Out.op])
    num_rows = Out.shape[0].value
    feat_len = Out.shape[1].value
    if num_cuda_blocks is None:
        # Default: one CUDA block per output row.
        num_cuda_blocks = num_rows
    if num_threads_per_cuda_block is None:
        # Cover a whole feature row with one block when it fits; otherwise
        # the feature axis spills over into several blocks along blockIdx.y.
        num_threads_per_cuda_block = min(feat_len, max_threads_per_cuda_block)
    row_axis = Out.op.axis[0]
    feat_axis = Out.op.axis[1]
    row_outer, row_inner = s[Out.op].split(row_axis, nparts=num_cuda_blocks)
    feat_outer, feat_inner = s[Out.op].split(
        feat_axis, factor=num_threads_per_cuda_block)
    # Loop nest order: feat_outer -> row_outer -> feat_inner -> row_inner.
    # The first three loops are bound to CUDA grid/block indices below;
    # row_inner stays a serial loop that each thread executes.
    s[Out.op].reorder(feat_outer, row_outer, feat_inner, row_inner)
    s[Out.op].bind(feat_outer, te.thread_axis("blockIdx.y"))
    s[Out.op].bind(row_outer, te.thread_axis("blockIdx.x"))
    s[Out.op].bind(feat_inner, te.thread_axis("threadIdx.x"))
    return s