in ppuda/ghn/nn.py [0:0]
def _map_net_params(self, graphs, nets_torch, sanity_check=False):
    r"""
    Matches the parameters in the models with the nodes in the graph.

    For every graph ``b`` in ``graphs`` (paired with the ``b``-th network), each
    node listed in ``graphs.node_info[b][cell_id]`` is matched to an entry of
    ``named_layered_modules(net)[cell_id]`` whose ``'param_name'`` starts with the
    node's parameter name. Matched nodes are grouped under a shape-dependent key.
    As an additional step, weights of the network that no node was matched to
    (and their biases) are pruned by setting them to ``None``.

    :param graphs: GraphBatch object; ``graphs.node_info`` and ``graphs.n_nodes``
        are read.
    :param nets_torch: a single neural network or a list/tuple of networks,
        paired with the graphs of ``graphs`` in order.
    :param sanity_check: if True, every module of a cell is scanned for each node
        (instead of stopping at the first match), a node matching several modules
        raises ``ValueError``, and nodes matching no module are asserted to have a
        name containing one of 'input', 'sum', 'concat', 'pool', 'glob_avg',
        'msa', 'cse'.
    :return: ``(mapping, params_map)`` where
        ``mapping`` maps a key ('cls_b', '1d', 'cls_w' or '4d-<k1>-<k2>') to the
        list of batch-level node indices assigned to that key, and
        ``params_map`` maps a batch-level node index to
        ``(module_entry, key, position of the node in mapping[key])``, or to
        ``({'sz': sz}, None, None)`` for an unmatched node whose ``sz`` is not None.
    :raises ValueError: if ``sanity_check`` is True and a node matches more than
        one module.
    """
    mapping = {}
    params_map = {}

    if not isinstance(nets_torch, (tuple, list)):
        nets_torch = [nets_torch]

    # NOTE(review): zip() silently stops at the shorter of graphs.node_info and
    # nets_torch; callers presumably pass exactly one network per graph.
    for b, (node_info, net) in enumerate(zip(graphs.node_info, nets_torch)):
        target_modules = named_layered_modules(net)
        # Offset of graph b's nodes in the batch-level node indexing.
        param_ind = torch.sum(graphs.n_nodes[:b]).item()

        for cell_id in range(len(node_info)):
            cell_modules = target_modules[cell_id]
            matched_names = []

            for (node_ind, param_name, name, sz, last_weight, last_bias) in node_info[cell_id]:
                node_id = param_ind + node_ind

                matched = []
                for m in cell_modules:
                    if m['param_name'].startswith(param_name):
                        matched.append(m)
                        if not sanity_check:
                            break  # outside sanity checks the first match is used

                if len(matched) > 1:
                    # Module entries are dicts: report each candidate's name and size.
                    raise ValueError(cell_id, node_ind, param_name, name, [
                        (m['param_name'], m['sz']) for m in matched])

                if not matched:
                    # Node with no module parameters in the network; keep its
                    # size (if any) so the caller still knows about it.
                    if sz is not None:
                        params_map[node_id] = ({'sz': sz}, None, None)
                    if sanity_check:
                        patterns = ['input', 'sum', 'concat', 'pool', 'glob_avg', 'msa', 'cse']
                        good = any(pattern in name for pattern in patterns)
                        assert good, \
                            (cell_id, param_name, name,
                             node_info[cell_id],
                             cell_modules)
                    continue

                module_entry = matched[0]
                matched_names.append(module_entry['param_name'])
                param_sz = module_entry['sz']
                if len(param_sz) == 1:
                    key = 'cls_b' if last_bias else '1d'
                elif last_weight:
                    key = 'cls_w'
                else:
                    # 2-D weights are treated as 1x1 kernels; otherwise the two
                    # trailing dims form the key. NOTE(review): shapes with a
                    # number of dims other than 2 or 4 make this format fail.
                    key = '4d-%d-%d' % ((1, 1) if len(param_sz) == 2 else tuple(param_sz[2:]))
                group = mapping.setdefault(key, [])
                params_map[node_id] = (module_entry, key, len(group))
                group.append(node_id)

            unique_names = set(matched_names)
            assert len(matched_names) == len(unique_names), (
                'all matched names must be unique to avoid predicting the same parameters for different modules',
                len(matched_names), len(unique_names))

            # Prune redundant ops in the network: weights that no node was
            # matched to (and their biases) are set to None.
            for m in cell_modules:
                if m['is_w'] and m['param_name'] not in unique_names:
                    m['module'].weight = None
                    if hasattr(m['module'], 'bias') and m['module'].bias is not None:
                        m['module'].bias = None

    return mapping, params_map