in python/featgraph/op/vanilla_sddmm.py [0:0]
def schedule_vanilla_sddmm_cuda_tree_reduce(Out, num_feat_partitions=1, num_cuda_blocks=8192):
    """Create a CUDA schedule for vanilla SDDMM using an in-block tree reduction.

    The edge axis (``Out.op.axis[0]``) is split into ``num_cuda_blocks`` outer
    parts, each bound to ``blockIdx.x``; the inner part of the split stays a
    serial loop inside each block. The first reduce axis of ``Out.op`` is bound
    to ``threadIdx.x`` so TVM emits a cross-thread (tree) reduction within a
    CUDA block.

    Parameters
    ----------
    Out : te.Tensor
        Output tensor of the SDDMM compute. ``Out.shape[0]`` must be a constant
        (it is read via ``.value``) and is treated as the number of edges.
        ``Out.op`` must have at least one reduce axis.
    num_feat_partitions : int
        Number of partitions of the feature dimension. Only ``1`` is supported,
        because tiling the feature dimension would need cross-block reduction
        and atomics.
    num_cuda_blocks : int
        Number of CUDA blocks; must satisfy ``1 <= num_cuda_blocks <= num_edges``.

    Returns
    -------
    te.Schedule
        The schedule for ``Out``.

    Raises
    ------
    AssertionError
        If any of the constraints above is violated. Note these checks are
        ``assert`` statements and are skipped under ``python -O``.
    """
    assert num_feat_partitions == 1, (
        "cuda schedule for sddmm does not support feat dimension tiling, "
        "which requires cross-cuda-block reduction and atomic operations."
    )
    num_edges = Out.shape[0].value
    # split(nparts=...) needs a positive part count; reject bad values up front
    # instead of failing deep inside TVM.
    assert num_cuda_blocks >= 1, (
        "num_cuda_blocks must be a positive integer, got {}.".format(num_cuda_blocks)
    )
    assert num_cuda_blocks <= num_edges, (
        "num_cuda_blocks must be no greater than num_edges, "
        "which is {}.".format(num_edges)
    )
    # The tree reduction below binds Out.op.reduce_axis[0]; make a missing
    # reduce axis a clear error rather than a bare IndexError.
    assert len(Out.op.reduce_axis) > 0, (
        "Out.op must have a reduce axis to schedule a tree reduction."
    )

    s = te.create_schedule([Out.op])
    edge_iter_axis = Out.op.axis[0]
    block_idx, _ = s[Out.op].split(edge_iter_axis, nparts=num_cuda_blocks)
    s[Out.op].bind(block_idx, te.thread_axis("blockIdx.x"))
    # Binding a reduce axis to threadIdx.x makes TVM lower it to a cross-thread
    # (tree) reduction inside each block.
    # NOTE(review): the reduce axis is not split, so its full extent becomes
    # the threadIdx.x extent; presumably it must not exceed the device's
    # max threads per block — confirm.
    s[Out.op].bind(Out.op.reduce_axis[0], te.thread_axis("threadIdx.x"))
    return s