def _map_net_params()

in ppuda/ghn/nn.py [0:0]


    def _map_net_params(self, graphs, nets_torch, sanity_check=False):
        r"""
        Matches the parameters in the models with the nodes in the graph.

        For every network in the batch, each node listed in ``graphs.node_info``
        is matched against the parameters of the same cell (as returned by
        ``named_layered_modules``) whose name starts with the node's
        ``param_name``. Matched nodes are grouped by a shape key
        ('1d', 'cls_b', 'cls_w' or '4d-<h>-<w>'). As a side effect, module
        weights (and their biases) that no node matched are pruned by setting
        them to None.

        :param graphs: GraphBatch object; its ``node_info`` and ``n_nodes``
            attributes are read.
        :param nets_torch: a single neural network or a list/tuple of networks,
            paired in order with the graphs in ``graphs.node_info``.
        :param sanity_check: if True, collect every matching parameter (instead
            of stopping at the first one), fail if a node matches more than one,
            and assert that unmatched nodes are parameter-free operations.
        :return: mapping, params_map
            mapping: dict shape key -> list of batch-wide node indices.
            params_map: dict batch-wide node index ->
            (matched module info dict, shape key, position in mapping[key]),
            or ({'sz': sz}, None, None) for an unmatched node with a known size.
        :raises ValueError: if ``sanity_check`` is True and a node matches more
            than one parameter.
        """
        mapping = {}
        params_map = {}

        if not isinstance(nets_torch, (tuple, list)):
            nets_torch = [nets_torch]

        # Node-name fragments of operations that legitimately own no parameters.
        param_free_ops = ('input', 'sum', 'concat', 'pool', 'glob_avg', 'msa', 'cse')

        for b, (node_info, net) in enumerate(zip(graphs.node_info, nets_torch)):
            target_modules = named_layered_modules(net)

            # Offset of this graph's nodes in the batch-wide node numbering.
            param_ind = torch.sum(graphs.n_nodes[:b]).item()

            for cell_id, cell_nodes in enumerate(node_info):
                cell_modules = target_modules[cell_id]
                matched_names = []
                for (node_ind, param_name, name, sz, last_weight, last_bias) in cell_nodes:
                    global_ind = param_ind + node_ind

                    matched = []
                    for m in cell_modules:
                        if m['param_name'].startswith(param_name):
                            matched.append(m)
                            if not sanity_check:
                                break  # outside of sanity checks the first match is used

                    if len(matched) > 1:
                        # Entries are dicts from named_layered_modules: report names and shapes.
                        raise ValueError(cell_id, node_ind, param_name, name, [
                            (m['param_name'], m['sz']) for m in matched])

                    if not matched:
                        if sz is not None:
                            params_map[global_ind] = ({'sz': sz}, None, None)

                        if sanity_check:
                            assert any(op in name for op in param_free_ops), \
                                (cell_id, param_name, name,
                                 cell_nodes,
                                 cell_modules)
                        continue

                    match = matched[0]
                    matched_names.append(match['param_name'])
                    p_sz = match['sz']  # shape of the matched parameter
                    if len(p_sz) == 1:
                        key = 'cls_b' if last_bias else '1d'
                    elif last_weight:
                        key = 'cls_w'
                    else:
                        # 2D weights are treated as 1x1 kernels; otherwise use the spatial dims.
                        key = '4d-%d-%d' % ((1, 1) if len(p_sz) == 2 else p_sz[2:])
                    group = mapping.setdefault(key, [])
                    params_map[global_ind] = (match, key, len(group))
                    group.append(global_ind)

                assert len(matched_names) == len(set(matched_names)), (
                    'all matched names must be unique to avoid predicting the same paramters for different moduels',
                    len(matched_names), len(set(matched_names)))
                matched_names = set(matched_names)

                # Prune redundant ops in Network by setting their params to None
                for m in cell_modules:
                    if m['is_w'] and m['param_name'] not in matched_names:
                        m['module'].weight = None
                        if hasattr(m['module'], 'bias') and m['module'].bias is not None:
                            m['module'].bias = None

        return mapping, params_map