def complexity()

in shaDow/layers.py [0:0]


    def complexity(self, dims_x_l, sizes_subg):
        """Estimate the output dims and the op count of this pooling layer.

        Args:
            dims_x_l: list of ``Dims_X`` records (``num_nodes``, ``num_feats``)
                describing the node features fed into pooling. The last entry
                is the one that gets pooled.
            sizes_subg: per-subgraph node counts (must support ``len`` and
                ``.sum()``, e.g. a numpy array). Its length is the number of
                output rows (one per subgraph); its sum must equal
                ``dims_x_l[-1].num_nodes``.

        Returns:
            Tuple ``(Dims_X(len(sizes_subg), dim_out), ops)``.

        Raises:
            NotImplementedError: if ``self.type_pool`` is not one of
                'center', 'max', 'mean', 'sum' or 'sort'.
        """
        num_nodes = len(sizes_subg)         # one output row per subgraph
        num_neigh = dims_x_l[-1].num_nodes
        assert num_neigh == sizes_subg.sum()
        if self.type_pool == 'center':
            if self.type_res == 'none':
                return Dims_X(num_nodes, dims_x_l[-1].num_feats), 0
            # regular JK: pool first (keep one row per subgraph), then residue
            dims_root_l = [Dims_X(num_nodes, d.num_feats) for d in dims_x_l]
            dim_f, ops = self.f_res_complexity(dims_root_l)
            # NOTE(review): this returns before the ``self.nn`` loop below, so
            # Linear layers in ``self.nn`` are not counted for 'center'.
            return Dims_X(num_nodes, dim_f), ops
        if self.type_pool in ('max', 'mean', 'sum'):
            # we first pool the graph: cost of pooling the last layer output,
            # scaled by the number of layer outputs when a residue is used
            mult = 1 if self.type_res == 'none' else len(dims_x_l)
            ops = dims_x_l[-1].num_nodes * dims_x_l[-1].num_feats * mult
            dims_root_l = [Dims_X(num_nodes, d.num_feats) for d in dims_x_l]
            _dim_f, ops_res = self.f_res_complexity(dims_root_l)
            ops += 2 * ops_res      # "2" since one for neighs, and the other for root
        elif self.type_pool == 'sort':
            ops = 0
            if self.type_res != 'none':
                _dim, ops = self.f_res_complexity(dims_x_l)
            # global_sort_pool only sorts along the last channel, so its cost
            # is neglected. ``self.nn_pool`` is either a single nn.Linear
            # (counted with an extra factor of ``self.k``) or an iterable
            # container whose Linear layers are counted once per subgraph.
            # NOTE(review): confirm which formula matches the forward pass.
            if isinstance(self.nn_pool, nn.Linear):
                ops += self.nn_pool.weight.numel() * self.k * num_nodes
            else:
                for n in self.nn_pool:
                    if isinstance(n, nn.Linear):
                        ops += n.weight.numel() * num_nodes
        else:
            raise NotImplementedError(
                f"complexity() does not support type_pool={self.type_pool!r}"
            )
        # The last nn.Linear in ``self.nn`` fixes the output feature dim
        # (dim 0 of a Linear weight is its output dim).
        # NOTE(review): if ``self.nn`` contains no nn.Linear, ``dim_f`` is
        # never bound on the 'max'/'mean'/'sum'/'sort' paths.
        for n in self.nn:
            if isinstance(n, nn.Linear):
                ops += n.weight.numel() * num_nodes
                dim_f = n.weight.shape[0]
        return Dims_X(num_nodes, dim_f), ops