in torchmoji/model_def.py [0:0]
def load_specific_weights(model, weight_path, exclude_names=None,
                          extend_embedding=0, verbose=True):
    """ Loads model weights from the given file path, excluding any
        given layers.

    # Arguments:
        model: Model whose weights should be loaded. Its state dict
            entries are overwritten in place.
        weight_path: Path to file containing model weights.
        exclude_names: List of layer names whose weights should not be
            loaded. A saved entry is skipped when any of these names
            occurs as a substring of its key. Defaults to no exclusions.
        extend_embedding: Number of new words being added to vocabulary.
        verbose: Verbosity flag.

    # Raises:
        ValueError if the file at weight_path does not exist, or if
            extend_embedding is set while 'embed' is in exclude_names.
        KeyError if the file holds a non-excluded entry that has no
            counterpart in the model's state dict.
    """
    # None stands in for "no exclusions" so that the default is not a
    # mutable list shared between calls.
    if exclude_names is None:
        exclude_names = []

    if not exists(weight_path):
        raise ValueError('ERROR (load_weights): The weights file at {} does '
                         'not exist. Refer to the README for instructions.'
                         .format(weight_path))

    if extend_embedding and 'embed' in exclude_names:
        raise ValueError('ERROR (load_weights): Cannot extend a vocabulary '
                         'without loading the embedding weights.')

    # Copy only weights from the temporary model that are wanted
    # for the specific task (e.g. the Softmax is often ignored)
    # NOTE: torch.load unpickles the file, so weight_path should only ever
    # point at a file from a trusted source.
    weights = torch.load(weight_path)

    # The state dict does not change between iterations; build it once
    # instead of once per saved entry.
    model_state = model.state_dict()

    for key, weight in weights.items():
        if any(excluded in key for excluded in exclude_names):
            if verbose:
                print('Ignoring weights for {}'.format(key))
            continue

        try:
            model_w = model_state[key]
        except KeyError:
            raise KeyError("Weights had parameters {},".format(key)
                           + " but could not find this parameters in model.")

        if verbose:
            print('Loading weights for {}'.format(key))

        # extend embedding layer to allow new randomly initialized words
        # if requested. Otherwise, just load the weights for the layer.
        if 'embed' in key and extend_embedding > 0:
            # Rows from NB_TOKENS onwards are taken from the model's own
            # current tensor and appended to the saved rows.
            weight = torch.cat((weight, model_w[NB_TOKENS:, :]), dim=0)
            if verbose:
                print('Extended vocabulary for embedding layer ' +
                      'from {} to {} tokens.'.format(
                          NB_TOKENS, NB_TOKENS + extend_embedding))

        try:
            # In-place copy into the tensor obtained from the state dict.
            model_w.copy_(weight)
        except Exception:
            # Report both shapes to help diagnose the failure, then let
            # the original error propagate unchanged.
            print('While copying the weigths named {}, whose dimensions in the model are'
                  ' {} and whose dimensions in the saved file are {}, ...'.format(
                      key, model_w.size(), weight.size()))
            raise