# vanilla_spmm_dds_x86()
#
# Defined in python/featgraph/op/vanilla_spmm.py [0:0]


def vanilla_spmm_dds_x86(SrcFeat,
                         Adj_s1_pos,
                         Adj_s1_idx,
                         Adj_vals,
                         d1_size,
                         d2_size,
                         num_feat_partitions=1):
    """Compute sparse-dense matrix multiplication of Adj and SrcFeat on x86.
    This implementation applies both feature dimension partitioning and 1D graph partitioning.
    1D graph partitioning transforms the csr Adj matrix into dense-dense-sparse (DDS) format.

    The computation is expressed in four stages:

    1. ``ReshapedSrcFeat`` re-lays SrcFeat out as
       [num_feat_partitions, num_src_vertices, feat_len_per_partition].
    2. ``Intermediate`` holds, for every (feature partition, src vertex partition,
       dst vertex, feature) tuple, the partial sum over the nonzeros of that
       dst row within that src vertex partition.
    3. ``ReshapedOut`` sums ``Intermediate`` over the src vertex partitions.
    4. ``Out`` maps the result back to the [num_dst_vertices, feat_len] layout.

    Parameters
    ----------
    SrcFeat : tvm.te.Tensor
        2-D with shape [num_src_vertices, feat_len]

    Adj_s1_pos : tvm.te.Tensor
        1-D with shape [d1_size * d2_size] (DDS)

    Adj_s1_idx : tvm.te.Tensor
        1-D with shape [nnz] (DDS)

    Adj_vals : tvm.te.Tensor
        1-D with shape [nnz] (DDS)

    d1_size : int
        Number of src vertex partitions

    d2_size : int
        num_dst_vertices + 1

    num_feat_partitions : int
        Doing feature dimension tiling; feat_len must be divisible by it

    Returns
    -------
    Out : tvm.te.Tensor
        2-D with shape [num_dst_vertices, feat_len]

    Raises
    ------
    AssertionError
        If the length of Adj_s1_pos is not d1_size * d2_size, if Adj_s1_idx and
        Adj_vals differ in length, or if a concrete feat_len is not a positive
        multiple of num_feat_partitions.
    """
    assert d1_size * d2_size == Adj_s1_pos.shape[0].value
    assert Adj_s1_idx.shape[0].value == Adj_vals.shape[0].value
    num_src_vertices, feat_len = get_const_tuple(SrcFeat.shape)

    # Without this check a non-divisible feat_len makes the final stage read
    # ReshapedOut[ff // feat_len_per_partition, ...] with a first index equal to
    # num_feat_partitions, i.e. one past the end of that axis.
    # NOTE(review): only enforced for concrete ints; presumably get_const_tuple
    # can leave a symbolic dimension as an expression -- confirm against TVM.
    if isinstance(feat_len, int):
        assert num_feat_partitions > 0 and feat_len % num_feat_partitions == 0, \
            "feat_len ({}) must be a positive multiple of num_feat_partitions ({})".format(
                feat_len, num_feat_partitions)

    num_src_vertex_partitions = d1_size
    num_dst_vertices = d2_size - 1
    oshape = (num_dst_vertices, feat_len)

    feat_len_per_partition = feat_len // num_feat_partitions
    # Ceiling division: the last src vertex partition may be smaller than the rest.
    num_src_vertices_per_partition = (
        (num_src_vertices + num_src_vertex_partitions - 1) // num_src_vertex_partitions)

    # Split the feature axis into (partition, offset-in-partition).
    ReshapedSrcFeat = te.compute(
        (num_feat_partitions, num_src_vertices, feat_len_per_partition),
        lambda fo, nn, fi: SrcFeat[nn, fo * feat_len_per_partition + fi],
        name='ReshapedSrcFeat')

    def msgfunc(fo, src_vertex_partition_idx, row, fi):
        # Each src vertex partition owns d2_size consecutive entries of
        # Adj_s1_pos; entries `row` and `row + 1` delimit the nonzeros of `row`.
        pos_base = src_vertex_partition_idx * d2_size + row
        row_start = Adj_s1_pos[pos_base]
        row_end = Adj_s1_pos[pos_base + 1]
        row_num_elems = row_end - row_start
        elem_idx = te.reduce_axis((0, row_num_elems), name="elem_idx")
        adj_val = Adj_vals[row_start + elem_idx]
        # The stored index is offset by the partition's first src vertex to
        # obtain the row of the feature matrix to read.
        src_vertex = (Adj_s1_idx[row_start + elem_idx]
                      + src_vertex_partition_idx * num_src_vertices_per_partition)
        feat_val = ReshapedSrcFeat[fo, src_vertex, fi]
        return te.sum(adj_val * feat_val, axis=elem_idx)

    Intermediate = te.compute(
        (num_feat_partitions, num_src_vertex_partitions, num_dst_vertices, feat_len_per_partition),
        msgfunc, name='Intermediate')

    # Combine the per-partition partial sums for every dst vertex.
    k = te.reduce_axis((0, num_src_vertex_partitions), name='src_vertex_partition_reduce')
    ReshapedOut = te.compute(
        (num_feat_partitions, num_dst_vertices, feat_len_per_partition),
        lambda fo, nn, fi: te.sum(Intermediate[fo, k, nn, fi], axis=k),
        name='ReshapedOut')

    # Undo the feature-axis split so callers see [num_dst_vertices, feat_len].
    Out = te.compute(
        oshape,
        lambda nn, ff: ReshapedOut[ff // feat_len_per_partition, nn, ff % feat_len_per_partition],
        name='Out')

    return Out