def schedule_vanilla_spmm_dds_x86()

in python/featgraph/op/vanilla_spmm.py [0:0]


def schedule_vanilla_spmm_dds_x86(Out):
    """Create an x86 TE schedule for the vanilla SpMM (DDS) graph rooted at ``Out``.

    The producer tensors are recovered by walking ``input_tensors`` back from
    ``Out``:

    * ``reshaped_out`` is input 0 of ``Out``;
    * ``intermediate`` is input 0 of ``reshaped_out``;
    * ``reshaped_src_feat`` is input 3 of ``intermediate``.

    ``intermediate`` is attached inside ``reshaped_out``'s first reduction
    loop. After that, one axis of each of the four stages is marked parallel.

    Parameters
    ----------
    Out : te.Tensor
        Final output tensor of the compute definition.

    Returns
    -------
    te.Schedule
        The schedule built for ``Out.op``.
    """
    sched = te.create_schedule([Out.op])

    # Recover the producer chain from the output tensor.
    reshaped_out = Out.op.input_tensors[0]
    intermediate = reshaped_out.op.input_tensors[0]
    reshaped_src_feat = intermediate.op.input_tensors[3]

    inter_op = intermediate.op
    rout_op = reshaped_out.op

    # Move intermediate's first reduction axis ahead of its 4th spatial axis.
    sched[inter_op].reorder(
        inter_op.axis[0],
        inter_op.axis[1],
        inter_op.axis[2],
        inter_op.reduce_axis[0],
        inter_op.axis[3],
    )
    # Put reshaped_out's reduction loop directly under its outermost axis.
    sched[rout_op].reorder(
        rout_op.axis[0],
        rout_op.reduce_axis[0],
        rout_op.axis[1],
        rout_op.axis[2],
    )
    # Compute intermediate inside reshaped_out's reduction loop.
    sched[inter_op].compute_at(sched[reshaped_out], rout_op.reduce_axis[0])

    # Intent (from the original author): parallelize the rows of the sparse
    # matrix. NOTE(review): which axis corresponds to "rows" is fixed by the
    # compute definition, which is not visible here — confirm.
    parallel_plan = (
        (reshaped_src_feat, 1),
        (intermediate, 2),
        (reshaped_out, 1),
        (Out, 0),
    )
    for stage_tensor, axis_idx in parallel_plan:
        sched[stage_tensor.op].parallel(stage_tensor.op.axis[axis_idx])

    return sched