in shaDow/layers.py [0:0]
def complexity(self, dims_x_l, sizes_subg):
    """Estimate the output dimensions and operation count of this pooling layer.

    The pooling reduces all nodes of each subgraph to a single row, so the
    output has ``len(sizes_subg)`` rows. The feature width and op count depend
    on ``self.type_pool`` ('center', 'max'/'mean'/'sum', or 'sort'), on
    ``self.type_res``, and on the ``nn.Linear`` layers of ``self.nn`` (and of
    ``self.nn_pool`` for 'sort').

    Args:
        dims_x_l: list of ``Dims_X`` records, one per preceding layer. The last
            entry describes the node features entering the pooling; its
            ``num_nodes`` must equal the total number of subgraph nodes.
        sizes_subg: per-subgraph node counts. ``len(sizes_subg)`` is the number
            of subgraphs and ``sizes_subg.sum()`` must equal
            ``dims_x_l[-1].num_nodes`` (presumably a numpy array or tensor --
            TODO confirm against callers).

    Returns:
        Tuple ``(Dims_X(num_subgraphs, dim_f), ops)``: the output dims and the
        estimated number of operations.

    Raises:
        AssertionError: if the node count of ``dims_x_l[-1]`` does not match
            ``sizes_subg.sum()``.
        NotImplementedError: if ``self.type_pool`` is not a supported type.
        ValueError: if ``self.nn`` contains no ``nn.Linear`` layer, so the
            output feature dimension cannot be determined.
    """
    num_nodes = len(sizes_subg)
    num_neigh = dims_x_l[-1].num_nodes
    assert num_neigh == sizes_subg.sum()
    if self.type_pool == 'center':
        if self.type_res == 'none':
            return Dims_X(num_nodes, dims_x_l[-1].num_feats), 0
        # regular JK: pool first (one row per subgraph), then apply the residue
        dims_root_l = [Dims_X(num_nodes, d.num_feats) for d in dims_x_l]
        dim_f, ops = self.f_res_complexity(dims_root_l)
        return Dims_X(num_nodes, dim_f), ops
    elif self.type_pool in ('max', 'mean', 'sum'):
        # we first pool the graph: one op per (node, feature) of each pooled layer
        ops = dims_x_l[-1].num_nodes * dims_x_l[-1].num_feats
        mult = 1 if self.type_res == 'none' else len(dims_x_l)
        ops *= mult
        dims_root_l = [Dims_X(num_nodes, d.num_feats) for d in dims_x_l]
        _dim_f, ops_res = self.f_res_complexity(dims_root_l)
        ops += 2 * ops_res  # "2" since one for neighs, and the other for root
    elif self.type_pool == 'sort':
        if self.type_res == 'none':
            ops = 0
        else:
            _dim, ops = self.f_res_complexity(dims_x_l)
        # global_sort_pool sorts only along the last channel, so the cost of
        # the sort itself is neglected.
        for layer in self.nn_pool:
            if isinstance(layer, nn.Linear):
                ops += np.prod(list(layer.weight.shape)) * num_nodes
        # NOTE(review): the loop above treats ``self.nn_pool`` as an iterable
        # container, while the next line reads ``self.nn_pool.weight`` as if it
        # were a single layer (a plain ``nn.Sequential`` has no ``weight``).
        # Confirm against how ``nn_pool`` is constructed in ``__init__``.
        ops += np.prod(list(self.nn_pool.weight.shape)) * self.k * num_nodes
    else:
        raise NotImplementedError(
            f"complexity() does not support pooling type {self.type_pool!r}"
        )
    # Post-pooling MLP: each Linear costs (out * in) ops per pooled row. The
    # output dim of the block is that of the last Linear (weight dim 0).
    dim_f = None
    for layer in self.nn:
        if isinstance(layer, nn.Linear):
            ops += np.prod(list(layer.weight.shape)) * num_nodes
            dim_f = layer.weight.shape[0]  # dim 0 is output dim
    if dim_f is None:
        raise ValueError(
            "cannot infer output dim: self.nn contains no nn.Linear layer"
        )
    return Dims_X(num_nodes, dim_f), ops