in python/featgraph/op/vanilla_spmm.py [0:0]
def schedule_vanilla_spmm_dds_x86(Out):
    """Create an x86 CPU schedule for the vanilla SpMM (DDS variant) compute.

    The producer stages are located by their input position, walking back
    from ``Out``:

    * ``ReshapedOut``     -- ``Out.op.input_tensors[0]``
    * ``Intermediate``    -- ``ReshapedOut.op.input_tensors[0]``
    * ``ReshapedSrcFeat`` -- ``Intermediate.op.input_tensors[3]``

    Scheduling steps:

    1. Reorder the ``Intermediate`` loops so that its first reduction axis
       sits between spatial axes 2 and 3.
    2. Reorder the ``ReshapedOut`` loops so that its first reduction axis
       sits directly after spatial axis 0.
    3. Attach ``Intermediate`` inside ``ReshapedOut``'s reduction loop via
       ``compute_at``.
    4. Mark one spatial axis of each of the four stages as parallel.

    Parameters
    ----------
    Out : te.Tensor
        Final output tensor of the SpMM compute definition. It must follow
        the producer layout described above.

    Returns
    -------
    te.Schedule
        The constructed schedule.
    """
    sched = te.create_schedule([Out.op])

    # Resolve producer operations by input index (see docstring for layout).
    reshaped_out_op = Out.op.input_tensors[0].op
    intermediate_op = reshaped_out_op.input_tensors[0].op
    reshaped_src_op = intermediate_op.input_tensors[3].op

    inter_stage = sched[intermediate_op]
    rout_stage = sched[reshaped_out_op]

    inter_axes = intermediate_op.axis
    inter_red = intermediate_op.reduce_axis[0]
    rout_axes = reshaped_out_op.axis
    rout_red = reshaped_out_op.reduce_axis[0]

    # Loop order for Intermediate: axis0, axis1, axis2, reduce0, axis3.
    inter_stage.reorder(
        inter_axes[0], inter_axes[1], inter_axes[2], inter_red, inter_axes[3]
    )
    # Loop order for ReshapedOut: axis0, reduce0, axis1, axis2.
    rout_stage.reorder(rout_axes[0], rout_red, rout_axes[1], rout_axes[2])
    # Compute Intermediate within each iteration of ReshapedOut's reduction
    # loop rather than as a separate root-level stage.
    inter_stage.compute_at(rout_stage, rout_red)

    # Parallelize the rows of the sparse matrix (per the original author's
    # note; the chosen axis index differs per stage).
    sched[reshaped_src_op].parallel(reshaped_src_op.axis[1])
    inter_stage.parallel(inter_axes[2])
    rout_stage.parallel(rout_axes[1])
    sched[Out.op].parallel(Out.op.axis[0])

    return sched