in python/featgraph/module/spmm.py [0:0]
def __init__(self, adj_scipy, num_col_partitions=1):
    """Convert the adjacency matrix to CSR and set up its TVM inputs.

    When ``num_col_partitions > 1`` this also does 1D graph partitioning
    (src vertex partitioning) during init.

    Parameters
    ----------
    adj_scipy : scipy.sparse.coo_matrix or scipy.sparse.csr_matrix
        The input scipy sparse matrix
    num_col_partitions : int
        Number of partitions along the col dimension (src vertices).
        Must be >= 1; a value of 1 means no partitioning.

    Raises
    ------
    ValueError
        If ``num_col_partitions`` is smaller than 1.
    """
    # Validate before any (potentially expensive) format conversion. An
    # explicit raise is used instead of ``assert`` so the check still runs
    # when Python is started with -O.
    if num_col_partitions < 1:
        raise ValueError("num_col_partitions should be larger than or equal to 1")

    # Use csr format in SpMM-like kernels; only convert when needed.
    if adj_scipy.format != 'csr':
        adj_scipy_csr = adj_scipy.tocsr()
    else:
        adj_scipy_csr = adj_scipy
    self._num_rows = adj_scipy_csr.shape[0]
    self._num_cols = adj_scipy_csr.shape[1]
    self._num_col_partitions = num_col_partitions

    # To be updated in self._register. It must run before the tvm.nd.array
    # calls below, because those allocate on self._ctx.
    self._target = None
    self._ctx = None
    self._compute_func = None
    self._schedule_func = None
    self._register()

    def _make_placeholder(arr, name):
        """Build a te.placeholder matching ``arr``'s shape and dtype."""
        return te.placeholder(shape=arr.shape, dtype=str(arr.dtype), name=name)

    if self._num_col_partitions > 1:
        # 1D graph partitioning along the col (src vertex) dimension.
        adj_s1_pos, adj_s1_idx, adj_vals = self._preprocess_adj(
            adj_scipy_csr, self._num_col_partitions)
        self._adj_s1_pos = adj_s1_pos
        self._adj_s1_idx = adj_s1_idx
        self._adj_vals = adj_vals
        self._adj_s1_pos_placeholder = _make_placeholder(
            self._adj_s1_pos, 'adj_s1_pos_placeholder')
        self._adj_s1_idx_placeholder = _make_placeholder(
            self._adj_s1_idx, 'adj_s1_idx_placeholder')
        self._adj_vals_placeholder = _make_placeholder(
            self._adj_vals, 'adj_vals_placeholder')
        self._adj_s1_pos_tvm = tvm.nd.array(self._adj_s1_pos, ctx=self._ctx)
        self._adj_s1_idx_tvm = tvm.nd.array(self._adj_s1_idx, ctx=self._ctx)
        self._adj_vals_tvm = tvm.nd.array(self._adj_vals, ctx=self._ctx)
        self._adj_d1_size = self._num_col_partitions
        # NOTE(review): presumably the per-partition row-pointer length
        # (num_rows + 1, like a CSR indptr) -- confirm against _preprocess_adj.
        self._adj_d2_size = self._num_rows + 1
    else:
        # No partitioning: feed the CSR arrays directly.
        self._adj_indptr = adj_scipy_csr.indptr
        self._adj_indices = adj_scipy_csr.indices
        self._adj_vals = adj_scipy_csr.data
        self._adj_indptr_placeholder = _make_placeholder(
            self._adj_indptr, 'adj_indptr_placeholder')
        self._adj_indices_placeholder = _make_placeholder(
            self._adj_indices, 'adj_indices_placeholder')
        self._adj_vals_placeholder = _make_placeholder(
            self._adj_vals, 'adj_vals_placeholder')
        self._adj_indptr_tvm = tvm.nd.array(self._adj_indptr, ctx=self._ctx)
        self._adj_indices_tvm = tvm.nd.array(self._adj_indices, ctx=self._ctx)
        self._adj_vals_tvm = tvm.nd.array(self._adj_vals, ctx=self._ctx)

    # To be updated in self.build
    self._func = None
    # To be updated in self.run
    self.out_tvm = None