in python/featgraph/op/vanilla_spmm.py [0:0]
def schedule_vanilla_spmm_csr_cuda(Out,
                                   num_cuda_blocks=None,
                                   num_threads_per_cuda_block=None):
    """Create a CUDA schedule for the operation that produces ``Out``.

    The first output axis (rows) is split into ``num_cuda_blocks`` parts and
    the second output axis (features) is split into chunks of
    ``num_threads_per_cuda_block``. The outer feature loop is bound to
    ``blockIdx.y``, the outer row loop to ``blockIdx.x`` and the inner feature
    loop to ``threadIdx.x``. The inner row loop is left unbound.

    Parameters
    ----------
    Out
        Output tensor to schedule. Its first two shape entries must expose a
        ``.value`` attribute; they are read even when both tuning parameters
        are supplied.
    num_cuda_blocks
        Number of parts the row axis is split into. Defaults to
        ``Out.shape[0].value`` when ``None``.
    num_threads_per_cuda_block
        Split factor for the feature axis. Defaults to
        ``Out.shape[1].value`` when ``None``.

    Returns
    -------
    The schedule returned by ``te.create_schedule`` with the splits, reorder
    and thread bindings applied.
    """
    s = te.create_schedule([Out.op])

    # Read both extents up front so a shape entry without ``.value`` fails
    # here regardless of which tuning parameters were passed.
    default_blocks = Out.shape[0].value
    default_threads = Out.shape[1].value
    blocks = default_blocks if num_cuda_blocks is None else num_cuda_blocks
    threads = (default_threads if num_threads_per_cuda_block is None
               else num_threads_per_cuda_block)

    stage = s[Out.op]
    rows, feats = Out.op.axis[0], Out.op.axis[1]

    # Rows are divided by part count, features by chunk size.
    row_outer, row_inner = stage.split(rows, nparts=blocks)
    feat_outer, feat_inner = stage.split(feats, factor=threads)
    stage.reorder(feat_outer, row_outer, feat_inner, row_inner)

    # Loop variable -> CUDA thread tag, bound in this order.
    bindings = (
        (feat_outer, "blockIdx.y"),
        (row_outer, "blockIdx.x"),
        (feat_inner, "threadIdx.x"),
    )
    for loop_var, tag in bindings:
        stage.bind(loop_var, te.thread_axis(tag))

    return s