def vanilla_spmm_csr_cuda()

in python/featgraph/op/vanilla_spmm.py [0:0]


def vanilla_spmm_csr_cuda(SrcFeat,
                          Adj_indptr,
                          Adj_indices,
                          Adj_vals):
    """Compute sparse-dense matrix multiplication of Adj and SrcFeat on cuda.
    This implementation does not transform the layout of SrcFeat.

    For every destination row ``row`` and feature column ``ff``::

        Out[row, ff] = sum_{k = Adj_indptr[row]}^{Adj_indptr[row + 1] - 1}
                           Adj_vals[k] * SrcFeat[Adj_indices[k], ff]

    Only the tensor-expression compute is defined here. The schedule
    (presumably a CUDA one, given the name) must be applied by the caller.

    Parameters
    ----------
    SrcFeat : tvm.te.Tensor
        2-D with shape [num_src_vertices, feat_len]

    Adj_indptr : tvm.te.Tensor
        1-D with shape [num_dst_vertices + 1] (CSR)

    Adj_indices : tvm.te.Tensor
        1-D with shape [nnz] (CSR)

    Adj_vals : tvm.te.Tensor
        1-D with shape [nnz] (CSR)

    Returns
    -------
    Out : tvm.te.Tensor
        2-D with shape [num_dst_vertices, feat_len]

    Raises
    ------
    ValueError
        If Adj_indices and Adj_vals have different lengths, if Adj_indptr is
        empty, or if any input does not have the documented rank.
    """
    # Tuple unpacking rejects CSR arrays that are not 1-D (ValueError).
    (nnz,) = get_const_tuple(Adj_indices.shape)
    (num_vals,) = get_const_tuple(Adj_vals.shape)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if nnz != num_vals:
        raise ValueError(
            f"Adj_indices and Adj_vals must have the same length, "
            f"got {nnz} and {num_vals}")

    (indptr_len,) = get_const_tuple(Adj_indptr.shape)
    # An empty indptr would yield a negative number of output rows. Only
    # checked when the length is a compile-time constant.
    if isinstance(indptr_len, int) and indptr_len < 1:
        raise ValueError("Adj_indptr must contain at least one element")
    num_dst_vertices = indptr_len - 1

    # The number of source vertices is not needed to build the compute.
    _, feat_len = get_const_tuple(SrcFeat.shape)
    oshape = (num_dst_vertices, feat_len)

    def msgfunc(row, ff):
        # The reduction extent depends on the row (data-dependent CSR
        # segment), so the reduce axis is created per row inside the lambda.
        row_start = Adj_indptr[row]
        row_end = Adj_indptr[row + 1]
        row_num_elems = row_end - row_start
        elem_idx = te.reduce_axis((0, row_num_elems), name="elem_idx")
        adj_val = Adj_vals[row_start + elem_idx]
        # Gather the source-vertex feature row addressed by the CSR index.
        feat_val = SrcFeat[Adj_indices[row_start + elem_idx], ff]
        return te.sum(adj_val * feat_val, axis=elem_idx)

    Out = te.compute(oshape, msgfunc, name='Out')

    return Out