in code/src/utils.py [0:0]
def reload_model(model, to_reload, attributes):
    """
    Reload a previously trained model.

    Copy every entry of ``to_reload.state_dict()`` into the matching entry of
    ``model.state_dict()`` in place, after checking that both models are
    compatible.

    Parameters:
        model: the model to update in place. Anything exposing ``state_dict()``
            whose values support ``size()`` and ``copy_()`` (e.g. a torch module).
        to_reload: the model whose parameters are copied into ``model``.
        attributes: iterable of attribute names to compare between the two
            models. Each item is either a name (``str``), compared strictly, or
            a ``(name, strict)`` tuple. A strict mismatch is an error; a
            non-strict mismatch is only logged as a warning. An attribute that
            is missing (or ``None``) on either model is always an error.

    Raises:
        AssertionError: if the two state dicts do not have the same keys, or if
            an item of ``attributes`` is neither a ``str`` nor a ``tuple``.
        ValueError: if an entry has a different size in the two models; in
            that case ``model`` is left unmodified.
        SystemExit: with exit status 1, after logging, if any attribute check
            produced an error.
    """
    import sys

    # build each state dict once instead of once per lookup
    model_state = model.state_dict()
    reload_state = to_reload.state_dict()

    # check parameter names (explicit raise: `assert` is stripped under -O)
    model_params = set(model_state.keys())
    to_reload_params = set(reload_state.keys())
    if model_params != to_reload_params:
        raise AssertionError((model_params - to_reload_params,
                              to_reload_params - model_params))

    # check attributes
    warning_msgs = []
    errors = []
    for k in attributes:
        if not isinstance(k, (tuple, str)):
            raise AssertionError('Attribute spec must be a str or a '
                                 '(name, strict) tuple, got %r' % (k,))
        k, strict = k if isinstance(k, tuple) else (k, True)
        # fetch with a default once, so a missing attribute is reported as an
        # error below instead of crashing while formatting the message
        model_value = getattr(model, k, None)
        reload_value = getattr(to_reload, k, None)
        if model_value is None:
            errors.append('- Attribute "%s" not found in the current model' % k)
        if reload_value is None:
            errors.append('- Attribute "%s" not found in the model to reload' % k)
        if model_value != reload_value:
            message = ('- Attribute "%s" differs between the current model (%s) '
                       'and the one to reload (%s)'
                       % (k, str(model_value), str(reload_value)))
            (errors if strict else warning_msgs).append(message)
    # `logger` is expected to be defined at module level in this file
    if warning_msgs:
        logger.warning('Different parameters:\n%s' % '\n'.join(warning_msgs))
    if errors:
        logger.error('Incompatible parameters:\n%s' % '\n'.join(errors))
        # non-zero status: this is a failure, not a normal exit
        sys.exit(1)

    # validate every size before copying anything, so a mismatch cannot leave
    # `model` partially overwritten
    for k, param in model_state.items():
        expected, got = param.size(), reload_state[k].size()
        if expected != got:
            raise ValueError("Expected tensor {} of size {}, but got {}".format(
                k, expected, got
            ))

    # copy saved parameters; this relies on the state_dict() values being
    # in-place views of the model's own storage (as with torch modules)
    for k, param in model_state.items():
        param.copy_(reload_state[k])