in python/featgraph/op/vanilla_sddmm.py [0:0]
def schedule_vanilla_sddmm_x86(Out, num_feat_partitions=1):
    """Create the x86 schedule for the vanilla SDDMM output tensor.

    A schedule is created over ``Out.op``. Unless ``num_feat_partitions`` is
    exactly 1, the first reduce axis of ``Out.op`` is split into
    ``num_feat_partitions`` outer parts and the loop nest is reordered to
    (outer reduce, first output axis, inner reduce).

    Parameters
    ----------
    Out : tensor-like object exposing ``op``
        Output whose ``op`` provides ``axis`` and ``reduce_axis``.
        NOTE(review): presumably a ``tvm.te.Tensor`` whose first axis iterates
        over edges and whose first reduce axis runs over the feature
        dimension -- confirm against the compute definition.
    num_feat_partitions : int, optional
        Number of outer parts the reduce axis is split into (passed as
        ``nparts``). The default of 1 leaves the schedule untouched.

    Returns
    -------
    The schedule created by ``te.create_schedule``.
    """
    sched = te.create_schedule([Out.op])

    # A single partition needs no loop transformation.
    if num_feat_partitions == 1:
        # TODO: parallelize ReshapedSrcFeat and ReshapedDstFeat
        return sched

    out_stage = sched[Out.op]
    edge_axis = Out.op.axis[0]
    reduce_axis = Out.op.reduce_axis[0]

    # Split the reduction into `num_feat_partitions` chunks and hoist the
    # chunk loop outermost, so each chunk is finished for every iteration of
    # the first output axis before the next chunk starts.
    reduce_outer, reduce_inner = out_stage.split(
        reduce_axis, nparts=num_feat_partitions
    )
    out_stage.reorder(reduce_outer, edge_axis, reduce_inner)

    # TODO: parallelize ReshapedSrcFeat and ReshapedDstFeat
    return sched