def reload_model()

in src/utils.py


import os

import torch


def reload_model(model, to_reload, attributes=None):
    """
    Reload the parameters of a previously trained model into `model`.

    `to_reload` is the path to a checkpoint created with `torch.save(model, path)`,
    i.e. a full model object rather than a bare state dict. `attributes` is an
    optional list of attribute names that must exist and match on both models.
    """
    # load the saved model object
    assert os.path.isfile(to_reload)
    to_reload = torch.load(to_reload)

    # check that both models have exactly the same parameter names
    model_params = set(model.state_dict().keys())
    to_reload_params = set(to_reload.state_dict().keys())
    assert model_params == to_reload_params, (model_params - to_reload_params,
                                              to_reload_params - model_params)

    # check that the requested attributes exist and have identical values
    attributes = [] if attributes is None else attributes
    for k in attributes:
        if getattr(model, k, None) is None:
            raise Exception('Attribute "%s" not found in the current model' % k)
        if getattr(to_reload, k, None) is None:
            raise Exception('Attribute "%s" not found in the model to reload' % k)
        if getattr(model, k) != getattr(to_reload, k):
            raise Exception('Attribute "%s" differs between the current model (%s) '
                            'and the one to reload (%s)'
                            % (k, str(getattr(model, k)), str(getattr(to_reload, k))))

    # copy the saved parameters into the current model, tensor by tensor
    model_state = model.state_dict()
    to_reload_state = to_reload.state_dict()
    for k, v in model_state.items():
        if v.size() != to_reload_state[k].size():
            raise Exception('Expected tensor %s of size %s, but got %s'
                            % (k, v.size(), to_reload_state[k].size()))
        v.copy_(to_reload_state[k])
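

A minimal usage sketch, assuming a checkpoint written with torch.save(model, path) (a full model object, not a bare state dict); the MyModel class, the checkpoint path, and the n_layers attribute below are hypothetical placeholders, not names from the repository:

import torch.nn as nn


class MyModel(nn.Module):
    def __init__(self, n_layers=2, dim=128):
        super().__init__()
        self.n_layers = n_layers
        self.layers = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(n_layers)])

    def forward(self, x):
        return self.layers(x)


model = MyModel()

# `reload_model` expects the checkpoint to be a full model object
# (saved with `torch.save(trained_model, path)`), since it calls
# `.state_dict()` on whatever `torch.load` returns.
reload_model(model, 'checkpoints/best_model.pth', attributes=['n_layers'])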