in python/featgraph/op/vanilla_sddmm.py [0:0]
def vanilla_sddmm(SrcFeat,
                  DstFeat,
                  Adj_row_indices,
                  Adj_col_indices,
                  num_feat_partitions=1):
    """Compute sampled dense dense matrix multiplication of SrcFeat and DstFeat with Adj matrix as mask.

    For every edge ``eid`` of the COO adjacency matrix the output is the dot product

        Out[eid] = sum_k SrcFeat[Adj_col_indices[eid], k] * DstFeat[Adj_row_indices[eid], k]

    i.e. SrcFeat rows are gathered by the edge's *column* index and DstFeat rows
    by the edge's *row* index.

    Parameters
    ----------
    SrcFeat : tvm.te.Tensor
        2-D with shape [num_src, feat_len]; rows are selected by Adj_col_indices.
    DstFeat : tvm.te.Tensor
        2-D with shape [num_dst, feat_len]; rows are selected by Adj_row_indices.
    Adj_row_indices : tvm.te.Tensor
        1-D with shape [nnz] (COO row indices)
    Adj_col_indices : tvm.te.Tensor
        1-D with shape [nnz] (COO column indices)
    num_feat_partitions : int
        Number of tiles to split the feature dimension into. Must be >= 1; when
        greater than 1 it must evenly divide feat_len. With more than one
        partition, SrcFeat and DstFeat are first re-laid out as
        [num_feat_partitions, num_vertices, feat_len // num_feat_partitions].

    Returns
    -------
    Out : tvm.te.Tensor
        1-D with shape [nnz] (COO), one dot product per edge.

    Raises
    ------
    AssertionError
        If the feature lengths of SrcFeat/DstFeat or the lengths of the two
        index tensors disagree, or if num_feat_partitions is not a positive
        divisor of feat_len.
    """
    # TODO: apply parallelization in cpu schedule
    # TODO: support tuning both block number and thread number in cuda schedule
    feat_len = get_const_tuple(SrcFeat.shape)[1]
    assert get_const_tuple(DstFeat.shape)[1] == feat_len, "dimension mismatch"
    num_edges = get_const_tuple(Adj_row_indices.shape)[0]
    assert get_const_tuple(Adj_col_indices.shape)[0] == num_edges, "dimension mismatch"
    assert num_feat_partitions >= 1, "num_feat_partitions must be a positive integer"
    oshape = (num_edges,)
    # Single reduction axis over the whole feature dimension, used by both paths.
    k = te.reduce_axis((0, feat_len))
    if num_feat_partitions == 1:
        def edgefunc(eid):  # eid: edge id
            return te.sum(SrcFeat[Adj_col_indices[eid], k] * DstFeat[Adj_row_indices[eid], k], axis=k)
    else:
        # The tiled reduction below maps k in [0, feat_len) to partition k // feat_len_per_partition.
        # If the partitions do not divide feat_len evenly, that index runs past the last
        # partition (out-of-bounds read), and a partition size of 0 divides by zero.
        assert feat_len % num_feat_partitions == 0, \
            "feat_len must be divisible by num_feat_partitions"
        feat_len_per_partition = feat_len // num_feat_partitions
        num_src = get_const_tuple(SrcFeat.shape)[0]
        num_dst = get_const_tuple(DstFeat.shape)[0]
        # Re-lay out features as [partition, vertex, feature-within-partition] so each
        # feature tile forms its own [num_vertices, feat_len_per_partition] block
        # (presumably to let the schedule tile the feature dimension -- confirm
        # against the schedule code).
        ReshapedSrcFeat = te.compute(
            (num_feat_partitions, num_src, feat_len_per_partition),
            lambda fo, nn, fi: SrcFeat[nn, fo * feat_len_per_partition + fi],
            name='ReshapedSrcFeat')
        ReshapedDstFeat = te.compute(
            (num_feat_partitions, num_dst, feat_len_per_partition),
            lambda fo, nn, fi: DstFeat[nn, fo * feat_len_per_partition + fi],
            name='ReshapedDstFeat')

        def edgefunc(eid):  # eid: edge id
            # k still spans the full feat_len; split it into (partition, offset-in-partition).
            fo = k // feat_len_per_partition
            fi = k % feat_len_per_partition
            return te.sum(ReshapedSrcFeat[fo, Adj_col_indices[eid], fi]
                          * ReshapedDstFeat[fo, Adj_row_indices[eid], fi], axis=k)
    Out = te.compute(oshape, edgefunc, name='vanilla_sddmm')
    return Out