in python/featgraph/module/spmm.py [0:0]
def build(self, input_placeholders, compute_args, schedule_args):
    """Compile the SpMM kernel and allocate its output buffer.

    The adjacency placeholders handed to the compute function and to
    ``tvm.build`` depend on ``self._num_col_partitions``:

    * ``> 1``: ``_adj_s1_pos_placeholder``, ``_adj_s1_idx_placeholder`` and
      ``_adj_vals_placeholder``; the compute function additionally receives
      ``_adj_d1_size`` and ``_adj_d2_size`` as positional arguments.
    * otherwise: ``_adj_indptr_placeholder``, ``_adj_indices_placeholder``
      and ``_adj_vals_placeholder``.

    Parameters
    ----------
    input_placeholders : list of te.placeholder
        Input tvm placeholders other than the adjacency ones (those were set
        up during ``self.init``). Iterated twice, so pass a re-iterable.
    compute_args : dict
        Keyword arguments forwarded to ``self._compute_func``,
        e.g. ``num_feat_partitions``.
    schedule_args : dict
        Keyword arguments forwarded to ``self._schedule_func``,
        e.g. ``num_cuda_blocks``.

    Side effects
    ------------
    Sets ``self._func`` to the built tvm module, and ``self.out_tvm`` to a
    zero-filled ``tvm.nd.array`` on ``self._ctx`` whose shape and dtype
    match the output placeholder.
    """
    partitioned = self._num_col_partitions > 1
    if partitioned:
        adj_tensors = [
            self._adj_s1_pos_placeholder,
            self._adj_s1_idx_placeholder,
            self._adj_vals_placeholder,
        ]
        # Only the column-partitioned compute variant takes the two sizes.
        size_args = [self._adj_d1_size, self._adj_d2_size]
    else:
        adj_tensors = [
            self._adj_indptr_placeholder,
            self._adj_indices_placeholder,
            self._adj_vals_placeholder,
        ]
        size_args = []

    # ``**`` unpacks the dicts into keyword arguments.
    result = self._compute_func(*input_placeholders, *adj_tensors, *size_args,
                                **compute_args)
    sched = self._schedule_func(result, **schedule_args)

    # The kernel signature is: user inputs, adjacency tensors, then output.
    kernel_args = [*input_placeholders, *adj_tensors, result]
    self._func = tvm.build(sched, kernel_args, target=self._target)

    out_shape = get_const_tuple(result.shape)
    host_zeros = np.zeros(shape=out_shape, dtype=str(result.dtype))
    self.out_tvm = tvm.nd.array(host_zeros, ctx=self._ctx)