def schedule_vanilla_spmm_csr_x86()

in python/featgraph/op/vanilla_spmm.py [0:0]


def schedule_vanilla_spmm_csr_x86(Out):
    """Create the x86 schedule for the vanilla CSR SpMM compute that ends in ``Out``.

    The schedule walks two levels up the compute graph from ``Out``:

    * ``Out.op.input_tensors[0]`` is the intermediate the code calls the
      reshaped output.
    * Input index 3 of that intermediate's op is the tensor the code calls the
      reshaped source features.

    NOTE(review): both indices are hard-coded and assume the compute
    definition orders its inputs that way. Confirm against the matching
    compute function.

    Transformations applied, in order:

    1. On the reshaped-output stage, move the first reduction axis between
       spatial axes 1 and 2, so that spatial axis 2 becomes the innermost loop.
    2. Mark spatial axis 1 of the reshaped source-feature stage as parallel.
    3. Mark spatial axis 1 of the reshaped-output stage as parallel.
    4. Mark spatial axis 0 of the final ``Out`` stage as parallel.

    The original authors describe the parallel axes as the rows of the sparse
    matrix.

    Args:
        Out: Final output tensor of the SpMM compute. Its op is the schedule root.

    Returns:
        The schedule produced by ``te.create_schedule`` after the
        transformations above.
    """
    sched = te.create_schedule([Out.op])

    reshaped_out = Out.op.input_tensors[0]
    reshaped_src_feat = reshaped_out.op.input_tensors[3]

    # Put the reduction between spatial axes 1 and 2, leaving axis 2 innermost.
    out_axes = reshaped_out.op.axis
    reduce_k = reshaped_out.op.reduce_axis[0]
    sched[reshaped_out.op].reorder(out_axes[0], out_axes[1], reduce_k, out_axes[2])

    # Multi-thread the row dimension of each stage, in the original call order.
    for stage_op, parallel_axis in (
        (reshaped_src_feat.op, reshaped_src_feat.op.axis[1]),
        (reshaped_out.op, out_axes[1]),
        (Out.op, Out.op.axis[0]),
    ):
        sched[stage_op].parallel(parallel_axis)

    return sched