in python/featgraph/module/sddmm.py [0:0]
def __init__(self, adj_scipy, num_row_partitions=1, num_col_partitions=1):
    """Prepare the adjacency indices, doing 2D graph partitioning if requested.

    The adjacency matrix is converted to COO format, optionally reordered
    by ``self._preprocess_adj`` according to the requested partitioning,
    and its row/col index arrays are exposed both as TVM placeholders and
    as TVM ndarrays on ``self._ctx``.

    Parameters
    ----------
    adj_scipy : scipy.sparse.coo_matrix or scipy.sparse.csr_matrix
        The input scipy sparse matrix
    num_row_partitions : int
        Number of partitions along the row dimension
    num_col_partitions : int
        Number of partitions along the col dimension

    Raises
    ------
    AssertionError
        If ``num_row_partitions`` or ``num_col_partitions`` is smaller than 1.
    """
    # Use coo format in SDDMM-like kernels
    adj_scipy_coo = adj_scipy if adj_scipy.format == 'coo' else adj_scipy.tocoo()
    self._num_rows = adj_scipy_coo.shape[0]
    self._num_cols = adj_scipy_coo.shape[1]

    assert num_row_partitions >= 1, "num_row_partitions should be larger than or equal to 1"
    assert num_col_partitions >= 1, "num_col_partitions should be larger than or equal to 1"
    self._num_row_partitions = num_row_partitions
    self._num_col_partitions = num_col_partitions

    # To be updated in self.register
    self._target = None
    self._ctx = None
    self._compute_func = None
    self._schedule_func = None
    self._register()

    # Always define the attribute so that reading it on an unpartitioned
    # instance yields None instead of raising AttributeError.
    self._edge_mapping = None

    # 2D graph partitioning
    if self._num_row_partitions > 1 or self._num_col_partitions > 1:
        edge_id_list, adj_row_indices, adj_col_indices = self._preprocess_adj(
            adj_scipy_coo, self._num_row_partitions, self._num_col_partitions)
        # argsort of the reordered edge ids; presumably edge_id_list is a
        # permutation of the original edge ids, in which case entry k is the
        # position of original edge k after partitioning (credit to Zihao).
        self._edge_mapping = np.argsort(edge_id_list)
    else:
        # No partitioning: edges keep the order of the COO matrix.
        adj_row_indices = adj_scipy_coo.row
        adj_col_indices = adj_scipy_coo.col
    self._adj_row_indices = adj_row_indices
    self._adj_col_indices = adj_col_indices

    # Symbolic inputs mirroring the shape and dtype of the index arrays.
    self._adj_row_indices_placeholder = te.placeholder(
        shape=self._adj_row_indices.shape,
        dtype=str(self._adj_row_indices.dtype),
        name='adj_row_indices_placeholder')
    self._adj_col_indices_placeholder = te.placeholder(
        shape=self._adj_col_indices.shape,
        dtype=str(self._adj_col_indices.dtype),
        name='adj_col_indices_placeholder')

    # Concrete index arrays placed on the context stored in self._ctx.
    self._adj_row_indices_tvm = tvm.nd.array(self._adj_row_indices, ctx=self._ctx)
    self._adj_col_indices_tvm = tvm.nd.array(self._adj_col_indices, ctx=self._ctx)

    # To be updated in self.build
    self._func = None
    # To be updated in self.run
    self.out_tvm = None