def load_specific_weights()

in torchmoji/model_def.py
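The function below relies on a few module-level imports that are not part of this excerpt. A minimal sketch of what they would look like, assuming the usual torchmoji layout (the import path for NB_TOKENS is an assumption):

from os.path import exists

import torch

from torchmoji.global_variables import NB_TOKENS  # assumed location of NB_TOKENS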


def load_specific_weights(model, weight_path, exclude_names=[], extend_embedding=0, verbose=True):
    """ Loads model weights from the given file path, excluding any
        given layers.

    # Arguments:
        model: Model whose weights should be loaded.
        weight_path: Path to file containing model weights.
        exclude_names: List of layer names whose weights should not be loaded.
        extend_embedding: Number of new words being added to vocabulary.
        verbose: Verbosity flag.

    # Raises:
        ValueError if the file at weight_path does not exist, or if
            extend_embedding is requested while the embedding layer is
            listed in exclude_names.
        KeyError if a parameter in the weights file cannot be found in
            the model.
    """
    if not exists(weight_path):
        raise ValueError('ERROR (load_weights): The weights file at {} does '
                         'not exist. Refer to the README for instructions.'
                         .format(weight_path))

    if extend_embedding and 'embed' in exclude_names:
        raise ValueError('ERROR (load_weights): Cannot extend a vocabulary '
                         'without loading the embedding weights.')

    # Copy only weights from the temporary model that are wanted
    # for the specific task (e.g. the Softmax is often ignored)
    weights = torch.load(weight_path)
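    # The saved file is treated as a state dict: a mapping from parameter
    # names to weight tensors.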
    for key, weight in weights.items():
        if any(excluded in key for excluded in exclude_names):
            if verbose:
                print('Ignoring weights for {}'.format(key))
            continue

        try:
            model_w = model.state_dict()[key]
        except KeyError:
            raise KeyError("Weights had parameters {},".format(key)
                           + " but could not find this parameters in model.")

        if verbose:
            print('Loading weights for {}'.format(key))

        # extend embedding layer to allow new randomly initialized words
        # if requested. Otherwise, just load the weights for the layer.
        if 'embed' in key and extend_embedding > 0:
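            # Keep the pretrained rows and append the model's freshly
            # initialized rows for the newly added vocabulary entries.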
            weight = torch.cat((weight, model_w[NB_TOKENS:, :]), dim=0)
            if verbose:
                print('Extended vocabulary for embedding layer ' +
                      'from {} to {} tokens.'.format(
                        NB_TOKENS, NB_TOKENS + extend_embedding))
        try:
            model_w.copy_(weight)
        except Exception:
            print('While copying the weights named {}, whose dimensions in the model are '
                  '{} and whose dimensions in the saved file are {}, an exception occurred.'
                  .format(key, model_w.size(), weight.size()))
            raise
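
A hypothetical usage sketch follows; the model constructor, weight path, and excluded layer name are placeholders rather than values taken from this file.

# Hypothetical example: load pretrained weights while skipping the final
# output layer and extending the vocabulary by 100 new tokens.
model = build_model(nb_tokens=NB_TOKENS + 100)                    # placeholder constructor
load_specific_weights(model,
                      weight_path='/path/to/pytorch_model.bin',   # placeholder path
                      exclude_names=['output_layer'],             # placeholder layer name
                      extend_embedding=100)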