def schedule_vanilla_sddmm_x86()

in python/featgraph/op/vanilla_sddmm.py [0:0]


def schedule_vanilla_sddmm_x86(Out, num_feat_partitions=1):
    """Create an x86 TVM schedule for the vanilla SDDMM output tensor.

    When ``num_feat_partitions`` is 1, the default schedule of ``Out.op`` is
    returned unchanged. Otherwise, the first reduction axis of ``Out.op`` is
    split into ``num_feat_partitions`` outer chunks. The loop nest is then
    reordered to (outer feature chunk, first data-parallel axis, inner feature
    chunk).

    Args:
        Out: tensor whose ``op`` is scheduled. Presumably this is the SDDMM
            result built elsewhere in this module, with ``op.axis[0]``
            iterating edges and ``op.reduce_axis[0]`` iterating features.
            TODO confirm against the compute definition.
        num_feat_partitions: number of outer partitions for the feature
            reduction axis. A value of 1 disables partitioning.

    Returns:
        The schedule produced by ``te.create_schedule`` for ``Out.op``,
        possibly with the split/reorder applied.
    """
    sched = te.create_schedule([Out.op])
    if num_feat_partitions == 1:
        # Nothing to partition: keep TVM's default loop nest.
        return sched

    edge_axis = Out.op.axis[0]
    feat_axis = Out.op.reduce_axis[0]
    out_stage = sched[Out.op]
    feat_outer, feat_inner = out_stage.split(feat_axis, nparts=num_feat_partitions)
    # Hoist the feature-chunk loop outside the edge loop so each chunk of
    # features is processed across all edges before moving to the next chunk.
    out_stage.reorder(feat_outer, edge_axis, feat_inner)
    # TODO: parallelize ReshapedSrcFeat and ReshapedDstFeat
    return sched