def schedule_vanilla_sddmm_cuda_tree_reduce()

in python/featgraph/op/vanilla_sddmm.py [0:0]


def schedule_vanilla_sddmm_cuda_tree_reduce(Out, num_feat_partitions=1, num_cuda_blocks=8192):
    """Create a CUDA schedule for vanilla SDDMM that reduces across threads.

    The first axis of ``Out`` (one iteration per edge) is split into
    ``num_cuda_blocks`` parts that are bound to ``blockIdx.x``, and the first
    reduce axis of ``Out`` is bound to ``threadIdx.x``.

    Parameters
    ----------
    Out : tensor
        Output tensor of the SDDMM computation. Its first dimension is read
        through ``Out.shape[0].value``, so it must be a constant extent.
    num_feat_partitions : int, optional
        Must be 1; feature dimension tiling is not supported by this schedule.
    num_cuda_blocks : int, optional
        Number of CUDA blocks the edge axis is divided into. Must lie between
        1 and the number of edges (inclusive).

    Returns
    -------
    s : schedule
        The schedule created by ``te.create_schedule`` for ``Out.op``.

    Raises
    ------
    AssertionError
        If ``num_feat_partitions`` is not 1, or if ``num_cuda_blocks`` is
        outside ``[1, num_edges]``.
    """
    # Validate the arguments before building anything so that bad input fails
    # fast with a clear message.
    assert num_feat_partitions == 1, "cuda schedule for sddmm does not support feat dimension tiling, " \
                                     "which requires cross-cuda-block reduction and atomic operations."
    num_edges = Out.shape[0].value
    # Equality is allowed (one edge per block); zero or negative counts are not.
    assert 1 <= num_cuda_blocks <= num_edges, \
        "num_cuda_blocks must be between 1 and num_edges (inclusive), " \
        "where num_edges is {}; got {}.".format(num_edges, num_cuda_blocks)

    s = te.create_schedule([Out.op])
    edge_iter_axis = Out.op.axis[0]
    # nparts fixes the number of outer iterations, i.e. the number of blocks.
    block_idx, _ = s[Out.op].split(edge_iter_axis, nparts=num_cuda_blocks)
    s[Out.op].bind(block_idx, te.thread_axis("blockIdx.x"))
    # Note: binding the reduce axis to threadIdx.x is what performs the tree
    # reduce this schedule is named after.
    s[Out.op].bind(Out.op.reduce_axis[0], te.thread_axis("threadIdx.x"))
    return s