in src/utils.py [0:0]
def reload_model(model, to_reload, attributes=None):
    """
    Reload a previously trained model into ``model`` in place.

    ``to_reload`` is the path of a file written by ``torch.save`` that holds a
    whole model object (its ``state_dict()`` and attributes are read below),
    not a bare state dict. Every tensor of ``model.state_dict()`` is
    overwritten with the matching tensor of the reloaded model.

    Args:
        model: model whose parameters/buffers are overwritten in place.
        to_reload: path of the checkpoint file to load.
        attributes: optional iterable of attribute names that must be present
            (and not ``None``) on both models and compare equal.

    Raises:
        FileNotFoundError: if ``to_reload`` is not an existing file.
        Exception: if the state-dict keys differ, a listed attribute is
            missing / ``None`` / different, or a tensor has a different size.
            All checks run before any tensor is copied, so ``model`` is left
            untouched on failure.
    """
    import inspect  # local: only used to probe torch.load's signature

    # reload the model (explicit check: `assert` is stripped under `python -O`)
    if not os.path.isfile(to_reload):
        raise FileNotFoundError('Model file "%s" not found' % to_reload)
    # NOTE: this unpickles an arbitrary Python object, so only reload
    # checkpoints that come from a trusted source.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        # torch >= 2.6 defaults to weights_only=True, which refuses to
        # unpickle a full nn.Module; this function needs the module itself.
        load_kwargs['weights_only'] = False
    reloaded = torch.load(to_reload, **load_kwargs)

    # state_dict() builds a fresh dict on every call, so build each one once
    model_state = model.state_dict()
    reloaded_state = reloaded.state_dict()

    # check parameter names
    model_params = set(model_state.keys())
    reloaded_params = set(reloaded_state.keys())
    if model_params != reloaded_params:
        raise Exception('Parameter names differ. Only in the current model: %s. '
                        'Only in the model to reload: %s'
                        % (sorted(model_params - reloaded_params),
                           sorted(reloaded_params - model_params)))

    # check attributes (an attribute whose value is None counts as missing)
    attributes = [] if attributes is None else attributes
    for k in attributes:
        if getattr(model, k, None) is None:
            raise Exception('Attribute "%s" not found in the current model' % k)
        if getattr(reloaded, k, None) is None:
            raise Exception('Attribute "%s" not found in the model to reload' % k)
        if getattr(model, k) != getattr(reloaded, k):
            raise Exception('Attribute "%s" differs between the current model (%s) '
                            'and the one to reload (%s)'
                            % (k, str(getattr(model, k)), str(getattr(reloaded, k))))

    # check every tensor size before copying anything, so a mismatch cannot
    # leave the model partially overwritten
    for k, tensor in model_state.items():
        saved = reloaded_state[k]
        if tensor.size() != saved.size():
            raise Exception("Expected tensor {} of size {}, but got {}".format(
                k, tensor.size(), saved.size()
            ))

    # copy saved parameters; state_dict() tensors share storage with the
    # model, so copy_ updates the model in place
    for k, tensor in model_state.items():
        tensor.copy_(reloaded_state[k])