def load_state_dict_to_model()

in lmgvp/transfer.py [0:0]


def load_state_dict_to_model(model, state_dict):
    """Initialize a model in place with parameters from `state_dict`.

    Intended for warm-starting from a pretrained model whose architecture
    differs slightly. Entries of `state_dict` whose name is not a key of
    `model.state_dict()` are skipped. Every other entry is copied into the
    model's tensor of the same name. Counts of keys seen and loaded are
    printed.

    Args:
        model: Torch model to update in place.
        state_dict: Mapping from parameter/buffer name to tensor (or
            `Parameter`) holding the weights to load.

    Returns:
        The input `model`, whose matching tensors now hold the values from
        `state_dict`.

    Raises:
        RuntimeError: If a matching entry's shape differs from the model's
            tensor of the same name.
    """
    own_state = model.state_dict()
    print("model own state keys:", len(own_state))
    print("state_dict keys:", len(state_dict))
    keys_loaded = 0
    for name, param in state_dict.items():
        if name not in own_state:
            # Layer absent from this architecture; nothing to load into.
            continue
        if isinstance(param, Parameter):
            # backwards compatibility for serialized parameters
            param = param.data
        target = own_state[name]
        # `Tensor.copy_` broadcasts its source. Without this check, a
        # mismatched but broadcastable shape would silently load wrong weights.
        if target.shape != param.shape:
            raise RuntimeError(
                f"shape mismatch for {name!r}: model has {tuple(target.shape)}, "
                f"state_dict has {tuple(param.shape)}"
            )
        # Tensors returned by `state_dict()` share storage with the model, so
        # copying into them updates the model's weights in place.
        target.copy_(param)
        keys_loaded += 1
    print("keys loaded into model:", keys_loaded)
    return model