def reload_model()

in code/src/utils.py [0:0]


def reload_model(model, to_reload, attributes):
    """
    Reload a previously trained model.

    Checks that ``model`` and ``to_reload`` expose the same state-dict keys
    and compatible configuration attributes. It then copies every tensor of
    ``to_reload.state_dict()`` into the matching tensor of
    ``model.state_dict()`` in place.

    Args:
        model: module whose state-dict tensors are overwritten.
        to_reload: module providing the values to copy.
        attributes: iterable of attribute names to compare. Each item is
            either a ``str`` (strict check) or a ``(name, strict)`` tuple.
            A differing strict attribute is an error; a differing
            non-strict attribute is only logged as a warning. An attribute
            that is missing (or ``None``) on either model is always an error.

    Raises:
        AssertionError: if the state-dict keys differ, or an item of
            ``attributes`` is neither a ``str`` nor a ``tuple``.
        SystemExit: with exit status 1 if any attribute check produced an
            error (the errors are logged first).
        Exception: if a tensor has a different size in the two models.
    """
    # Build each state dict once: every call rebuilds the whole mapping, and
    # the returned tensors share storage with the model, so copying into
    # them later still updates ``model``.
    model_state = model.state_dict()
    reload_state = to_reload.state_dict()

    # check that both models have the same set of state-dict keys
    model_params = set(model_state)
    to_reload_params = set(reload_state)
    assert model_params == to_reload_params, (model_params - to_reload_params,
                                              to_reload_params - model_params)

    # check attributes
    warnings = []
    errors = []
    for item in attributes:
        assert isinstance(item, (tuple, str))
        name, strict = item if isinstance(item, tuple) else (item, True)
        # Fetch once with a default so a missing attribute never raises here.
        model_value = getattr(model, name, None)
        reload_value = getattr(to_reload, name, None)
        if model_value is None:
            errors.append('- Attribute "%s" not found in the current model' % name)
        if reload_value is None:
            errors.append('- Attribute "%s" not found in the model to reload' % name)
        if model_value is None or reload_value is None:
            # Already reported as missing; a "differs" message would be noise.
            continue
        if model_value != reload_value:
            message = ('- Attribute "%s" differs between the current model (%s) '
                       'and the one to reload (%s)'
                       % (name, model_value, reload_value))
            (errors if strict else warnings).append(message)
    if warnings:
        logger.warning('Different parameters:\n%s', '\n'.join(warnings))
    if errors:
        logger.error('Incompatible parameters:\n%s', '\n'.join(errors))
        # Non-zero status so callers/schedulers can tell the reload failed.
        raise SystemExit(1)

    # copy saved parameters, in the model's state-dict order
    for key, target in model_state.items():
        source = reload_state[key]
        if target.size() != source.size():
            raise Exception("Expected tensor {} of size {}, but got {}".format(
                key, target.size(), source.size()
            ))
        target.copy_(source)