def forward()

in ppuda/ghn/nn.py


    def forward(self, nets_torch, graphs=None, return_embeddings=False, predict_class_layers=True, bn_train=True):
        r"""
        Predict parameters for a list of >=1 networks.
        :param nets_torch: one network or a list of networks, each is based on nn.Module.
                           In case of evaluation, only one network can be passed.
        :param graphs: GraphBatch object in case of training.
                       For evaluation, graphs can be None and will be constructed on the fly given the nets_torch in this case.
        :param return_embeddings: True to return the node embeddings obtained after the last graph propagation step.
                                  return_embeddings=True is used for property prediction experiments.
        :param predict_class_layers: default=True predicts all parameters including the classification layers.
                                     predict_class_layers=False is used in fine-tuning experiments.
        :param bn_train: default=True sets BN layers in nets_torch into the training mode (required to evaluate predicted parameters)
                        bn_train=False is used in fine-tuning experiments
        :return: nets_torch with predicted parameters and node embeddings if return_embeddings=True
        """

        if not self.training:
            assert isinstance(nets_torch,
                              nn.Module) or len(nets_torch) == 1, \
                'constructing the graph on the fly is only supported for a single network'

            if isinstance(nets_torch, list):
                nets_torch = nets_torch[0]

            if self.debug_level:
                if self.debug_level > 1:
                    valid_ops = graphs[0].num_valid_nodes(nets_torch)
                start_time = time.time()  # do not count any debugging steps above

            if graphs is None:
                graphs = GraphBatch([Graph(nets_torch, ve_cutoff=50 if self.ve else 1)])
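                # ve_cutoff=50 presumably adds virtual edges between nodes up
                # to 50 hops apart when virtual edges (self.ve) are enabled;
                # ve_cutoff=1 keeps only the real 1-hop edges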
                graphs.to_device(self.embed.weight.device)

        else:
            assert graphs is not None, \
                'constructing the graph on the fly is only supported in the evaluation mode'

        # Find mapping between embeddings and network parameters
        param_groups, params_map = self._map_net_params(graphs, nets_torch, self.debug_level > 0)

        if self.debug_level or not self.training:
            n_params_true = sum([capacity(net)[1] for net in (nets_torch if isinstance(nets_torch, list) else [nets_torch])])
            if self.debug_level > 1:
                print('\nnumber of learnable parameter tensors: {}, total number of parameters: {}'.format(
                    valid_ops, n_params_true))

        # Obtain initial embeddings for all nodes
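        # (self.embed maps each node's op id, node_feat[:, 0], to an embedding;
        # shape_enc presumably conditions it on the shape of the parameter
        # tensor associated with the node, looked up via params_map)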
        x = self.shape_enc(self.embed(graphs.node_feat[:, 0]), params_map, predict_class_layers=predict_class_layers)

        # Update node embeddings using a GatedGNN, MLP or another model
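        # (node_feat[:, 1] presumably stores each node's graph index within
        # the batch so that propagation stays within each network's own graph)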
        x = self.gnn(x, graphs.edges, graphs.node_feat[:, 1])

        if self.layernorm:
            x = self.ln(x)

        # Predict max-sized parameters for a batch of nets using decoders
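        # Keys in param_groups select a decoder head and target shape: 'cls_w'
        # is the classification weight, 'cls_b' its bias, '4d-H-W' keys
        # presumably encode convolutional kernel sizes, and the remaining 1d
        # group covers weight/bias vectors (e.g. of BN/LN), predicted jointly
        # as a (num_nodes, 2, dim) tensor.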
        w = {}
        for key, inds in param_groups.items():
            if len(inds) == 0:
                continue
            x_ = x[torch.tensor(inds, device=x.device)]
            if key == 'cls_w':
                w[key] = self.decoder(x_, (1, 1), class_pred=True)
            elif key.startswith('4d'):
                sz = tuple(map(int, key.split('-')[1:]))
                w[key] = self.decoder(x_, sz, class_pred=False)
            else:
                w[key] = self.decoder_1d(x_).view(len(inds), 2, -1)
                if key == 'cls_b':
                    w[key] = self.bias_class(w[key])

        # Transfer predicted parameters (w) to the networks
        n_tensors, n_params = 0, 0
        for matched, key, w_ind in params_map.values():

            if w_ind is None:
                continue  # e.g. pooling

            if not predict_class_layers and key in ['cls_w', 'cls_b']:
                continue  # do not set the classification parameters when fine-tuning

            m, sz, is_w = matched['module'], matched['sz'], matched['is_w']
            for it in range(2 if (len(sz) == 1 and is_w) else 1):

                if len(sz) == 1:
                    # BN/LN biases are set separately as they are
                    # not represented as separate nodes in graphs
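                    # (w[key] has shape (num_nodes, 2, dim), where index 0
                    # along dim 1 is the weight and index 1 the bias, hence
                    # the 1 - is_w + it indexing below)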
                    w_ = w[key][w_ind][1 - is_w + it]
                    if it == 1:
                        assert (type(m) in NormLayers and key == '1d'), \
                            (type(m), key)
                else:
                    w_ = w[key][w_ind]

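                # _tile_params presumably tiles/slices the fixed-size decoder
                # output to the target shape sz (the GHN predicts max-sized
                # tensors that are adapted per layer), and _set_params copies
                # the result into the module's parameter in-place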
                sz_set = self._set_params(m, self._tile_params(w_, sz), is_w=is_w & ~it)
                n_tensors += 1
                n_params += torch.prod(torch.tensor(sz_set))

        if not self.training and bn_train:

            def bn_set_train(module):
                if isinstance(module, nn.BatchNorm2d):
                    module.track_running_stats = False
                    module.training = True

            nets_torch.apply(bn_set_train)  # set BN layers to the training mode to enable evaluation without having running statistics

        if self.debug_level and not self.training:

            end_time = time.time() - start_time

            print('number of parameter tensors predicted using GHN: {}, '
                  'total parameters predicted: {} ({}), time to predict (on {}): {:.4f} sec'.format(
                n_tensors,
                n_params,
                ('matched!' if n_params_true == n_params else 'error! not matched').upper(),
                str(x.device).upper(),
                end_time))

            if self.debug_level > 1:
                assert valid_ops == n_tensors, (
                    'number of learnable tensors ({}) must be the same as the number of predicted tensors ({})'.format(
                        valid_ops, n_tensors))

            if self.debug_level > 2:
                print('predicted parameter stats:')
                for n, p in nets_torch.named_parameters():
                    print('{:30s} ({:30s}): min={:.3f} \t max={:.3f} \t mean={:.3f} \t std={:.3f} \t norm={:.3f}'.format(
                        n[:30],
                        str(p.shape)[:30],
                        p.min().item(),
                        p.max().item(),
                        p.mean().item(),
                        p.std().item(),
                        torch.norm(p).item()))
        elif self.debug_level or not self.training:
            assert n_params == n_params_true, ('the number of predicted ({}) and actual ({}) parameters must match'.format(
                n_params, n_params_true))

        return (nets_torch, x) if return_embeddings else nets_torch
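
Usage sketch (illustrative, not part of the file): assuming the pretrained
GHN2 wrapper from the same module, parameters of a single network can be
predicted at evaluation time, with its graph constructed on the fly:

    import torch
    import torchvision
    from ppuda.ghn.nn import GHN2

    ghn = GHN2('cifar10')  # assumed pretrained GHN-2 checkpoint for CIFAR-10
    net = torchvision.models.resnet18(num_classes=10)
    with torch.no_grad():
        net = ghn(net)  # predicts and sets all parameters of net
    # alternatively, also return node embeddings for property prediction:
    # net, x = ghn(net, return_embeddings=True)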