in ppuda/ghn/nn.py [0:0]
def forward(self, nets_torch, graphs=None, return_embeddings=False, predict_class_layers=True, bn_train=True):
r"""
Predict parameters for a list of >=1 networks.
:param nets_torch: one network or a list of networks, each is based on nn.Module.
In case of evaluation, only one network can be passed.
:param graphs: GraphBatch object in case of training.
For evaluation, graphs can be None and will be constructed on the fly given the nets_torch in this case.
:param return_embeddings: True to return the node embeddings obtained after the last graph propagation step.
return_embeddings=True is used for property prediction experiments.
:param predict_class_layers: default=True predicts all parameters including the classification layers.
predict_class_layers=False is used in fine-tuning experiments.
:param bn_train: default=True sets BN layers in nets_torch into the training mode (required to evaluate predicted parameters)
bn_train=False is used in fine-tuning experiments
:return: nets_torch with predicted parameters and node embeddings if return_embeddings=True
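Example (a minimal usage sketch; assumes a trained GHN instance `ghn` in eval mode
and that torchvision is installed, neither of which is defined in this file):
>>> import torchvision.models as models
>>> net = models.resnet18(num_classes=10)
>>> net = ghn(net)  # graph is built on the fly; predicted parameters are set in-place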
"""
if not self.training:
assert isinstance(nets_torch, nn.Module) or len(nets_torch) == 1, \
'constructing the graph on the fly is only supported for a single network'
if isinstance(nets_torch, list):
nets_torch = nets_torch[0]
if self.debug_level:
if self.debug_level > 1:
valid_ops = graphs[0].num_valid_nodes(nets_torch)
start_time = time.time() # do not count any debugging steps above
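# In evaluation mode the computational graph of the network is built on the fly;
# ve_cutoff=50 adds virtual edges between distant nodes when virtual edges (self.ve) are enabled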
if graphs is None:
graphs = GraphBatch([Graph(nets_torch, ve_cutoff=50 if self.ve else 1)])
graphs.to_device(self.embed.weight.device)
else:
assert graphs is not None, \
'constructing the graph on the fly is only supported in the evaluation mode'
# Find mapping between embeddings and network parameters
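# params_map maps each graph node to its matched module together with the decoder key
# and the index of the node within its decoder group (both are consumed in the loops below)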
param_groups, params_map = self._map_net_params(graphs, nets_torch, self.debug_level > 0)
if self.debug_level or not self.training:
n_params_true = sum([capacity(net)[1] for net in (nets_torch if isinstance(nets_torch, list) else [nets_torch])])
if self.debug_level > 1:
print('\nnumber of learnable parameter tensors: {}, total number of parameters: {}'.format(
valid_ops, n_params_true))
# Obtain initial embeddings for all nodes
x = self.shape_enc(self.embed(graphs.node_feat[:, 0]), params_map, predict_class_layers=predict_class_layers)
# Update node embeddings using a GatedGNN, MLP or another model
x = self.gnn(x, graphs.edges, graphs.node_feat[:, 1])
if self.layernorm:
x = self.ln(x)
# Predict max-sized parameters for a batch of nets using decoders
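# Each group is handled by the decoder suited to its shape: classification weights ('cls_w'),
# 4d convolutional tensors ('4d-*-*'), and 1d vectors such as biases and norm-layer weights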
w = {}
for key, inds in param_groups.items():
if len(inds) == 0:
continue
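# gather the embeddings of the nodes whose parameters belong to this group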
x_ = x[torch.tensor(inds, device=x.device)]
if key == 'cls_w':
w[key] = self.decoder(x_, (1, 1), class_pred=True)
elif key.startswith('4d'):
sz = tuple(map(int, key.split('-')[1:]))
w[key] = self.decoder(x_, sz, class_pred=False)
else:
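# decoder_1d predicts two vectors per node: index 0 holds the weight and index 1 the bias
# (they are selected via 1 - is_w + it when transferring parameters below)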
w[key] = self.decoder_1d(x_).view(len(inds), 2, -1)
if key == 'cls_b':
w[key] = self.bias_class(w[key])
# Transfer predicted parameters (w) to the networks
n_tensors, n_params = 0, 0
for matched, key, w_ind in params_map.values():
if w_ind is None:
continue # e.g. pooling
if not predict_class_layers and key in ['cls_w', 'cls_b']:
continue # do not set the classification parameters when fine-tuning
m, sz, is_w = matched['module'], matched['sz'], matched['is_w']
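# 1d weight nodes (e.g. affine weights of BN/LN layers) also carry the bias of the same layer,
# so for them the loop below runs twice: once for the weight (it=0) and once for the bias (it=1)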
for it in range(2 if (len(sz) == 1 and is_w) else 1):
if len(sz) == 1:
# biases of BN/LN layers are set in the second pass (it == 1),
# since they are not represented as separate nodes in the graphs
w_ = w[key][w_ind][1 - is_w + it]
if it == 1:
assert type(m) in NormLayers and key == '1d', (type(m), key)
else:
w_ = w[key][w_ind]
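# _tile_params tiles/slices the predicted max-sized tensor to the target shape sz;
# is_w & ~it keeps the weight flag only on the first iteration (the second one sets a bias)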
sz_set = self._set_params(m, self._tile_params(w_, sz), is_w=is_w & ~it)
n_tensors += 1
n_params += torch.prod(torch.tensor(sz_set))
if not self.training and bn_train:
def bn_set_train(module):
if isinstance(module, nn.BatchNorm2d):
module.track_running_stats = False
module.training = True
nets_torch.apply(bn_set_train)  # set BN layers to training mode so that the predicted parameters can be evaluated without accumulated running statistics
if self.debug_level and not self.training:
end_time = time.time() - start_time
print('number of parameter tensors predicted using GHN: {}, '
'total parameters predicted: {} ({}), time to predict (on {}): {:.4f} sec'.format(
n_tensors,
n_params,
('matched!' if n_params_true == n_params else 'error! not matched').upper(),
str(x.device).upper(),
end_time))
if self.debug_level > 1:
assert valid_ops == n_tensors, (
'number of learnable tensors ({}) must be the same as the number of predicted tensors ({})'.format(
valid_ops, n_tensors))
if self.debug_level > 2:
print('predicted parameter stats:')
for n, p in nets_torch.named_parameters():
print('{:30s} ({:30s}): min={:.3f} \t max={:.3f} \t mean={:.3f} \t std={:.3f} \t norm={:.3f}'.format(
n[:30],
str(p.shape)[:30],
p.min().item(),
p.max().item(),
p.mean().item(),
p.std().item(),
torch.norm(p).item()))
elif self.debug_level or not self.training:
assert n_params == n_params_true, ('number of predicted ({}) and actual ({}) parameters must match'.format(
n_params, n_params_true))
return (nets_torch, x) if return_embeddings else nets_torch