def vanilla_sddmm()

in python/featgraph/op/vanilla_sddmm.py [0:0]


def vanilla_sddmm(SrcFeat,
                  DstFeat,
                  Adj_row_indices,
                  Adj_col_indices,
                  num_feat_partitions=1):
    """Compute sampled dense dense matrix multiplication of SrcFeat and DstFeat with Adj matrix as mask.

    For every stored entry ``e`` of the COO adjacency matrix, the output holds
    the dot product of one row of ``SrcFeat`` and one row of ``DstFeat``::

        Out[e] = sum_k SrcFeat[Adj_col_indices[e], k] * DstFeat[Adj_row_indices[e], k]

    Note that ``SrcFeat`` is gathered with the *column* indices and
    ``DstFeat`` with the *row* indices.

    Parameters
    ----------
    SrcFeat : tvm.te.Tensor
        2-D with shape [n_src, feat_len]; row ``Adj_col_indices[e]`` is read
        for edge ``e``.

    DstFeat : tvm.te.Tensor
        2-D with shape [n_dst, feat_len]; row ``Adj_row_indices[e]`` is read
        for edge ``e``.

    Adj_row_indices : tvm.te.Tensor
        1-D with shape [nnz] (COO row indices)

    Adj_col_indices : tvm.te.Tensor
        1-D with shape [nnz] (COO column indices)

    num_feat_partitions : int
        Number of partitions used for feature-dimension tiling. Must be a
        positive integer, and (for a static ``feat_len``) a divisor of
        ``feat_len`` no larger than it. When greater than 1, both feature
        tensors are first re-laid out by the compute ops ``ReshapedSrcFeat``
        / ``ReshapedDstFeat`` with shape
        [num_feat_partitions, n, feat_len // num_feat_partitions].

    Returns
    -------
    Out : tvm.te.Tensor
        1-D with shape [nnz] (COO), produced by a compute op named
        ``vanilla_sddmm``.

    Raises
    ------
    AssertionError
        If the feature lengths of SrcFeat and DstFeat differ, or the two
        index tensors have different lengths.
    ValueError
        If ``num_feat_partitions`` is not positive, or does not evenly split
        a static ``feat_len``.
    """
    # TODO: apply parallelization in cpu schedule
    # TODO: support tuning both block number and thread number in cuda schedule
    if num_feat_partitions < 1:
        # Otherwise the partitioned branch below divides by zero (or builds
        # tensors with a negative dimension).
        raise ValueError(
            "num_feat_partitions must be a positive integer, got {!r}".format(num_feat_partitions))

    feat_len = get_const_tuple(SrcFeat.shape)[1]
    assert get_const_tuple(DstFeat.shape)[1] == feat_len, "dimension mismatch"
    num_edges = get_const_tuple(Adj_row_indices.shape)[0]
    assert get_const_tuple(Adj_col_indices.shape)[0] == num_edges, "dimension mismatch"
    oshape = (num_edges,)

    # A single reduction axis over the full feature length; the partitioned
    # branch decomposes it into (partition, offset) via // and %.
    k = te.reduce_axis((0, feat_len))

    if num_feat_partitions == 1:
        def edgefunc(eid):  # eid: edge id
            return te.sum(SrcFeat[Adj_col_indices[eid], k] * DstFeat[Adj_row_indices[eid], k], axis=k)
    else:
        # The reshaped layout below only covers all feat_len features when the
        # split is exact: with a remainder, k // feat_len_per_partition would
        # index past the last partition, and num_feat_partitions > feat_len
        # would make each partition empty. Only a static (Python int) length
        # can be checked here; a symbolic one is passed through unchanged.
        if isinstance(feat_len, int) and (
                num_feat_partitions > feat_len or feat_len % num_feat_partitions != 0):
            raise ValueError(
                "num_feat_partitions ({}) must evenly divide feat_len ({})".format(
                    num_feat_partitions, feat_len))
        feat_len_per_partition = feat_len // num_feat_partitions
        num_rows = get_const_tuple(SrcFeat.shape)[0]
        num_cols = get_const_tuple(DstFeat.shape)[0]
        ReshapedSrcFeat = te.compute(
            (num_feat_partitions, num_rows, feat_len_per_partition),
            lambda fo, nn, fi: SrcFeat[nn, fo * feat_len_per_partition + fi],
            name='ReshapedSrcFeat')
        ReshapedDstFeat = te.compute(
            (num_feat_partitions, num_cols, feat_len_per_partition),
            lambda fo, nn, fi: DstFeat[nn, fo * feat_len_per_partition + fi],
            name='ReshapedDstFeat')

        def edgefunc(eid):  # eid: edge id
            fo = k // feat_len_per_partition  # which feature partition
            fi = k % feat_len_per_partition   # offset inside that partition
            return te.sum(ReshapedSrcFeat[fo, Adj_col_indices[eid], fi]
                          * ReshapedDstFeat[fo, Adj_row_indices[eid], fi], axis=k)

    Out = te.compute(oshape, edgefunc, name='vanilla_sddmm')
    return Out