in src/fairseq/fairseq/models/bart/model.py [0:0]
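# Relies on this file's module-level imports (present elsewhere in model.py):
# torch, torch.nn as nn, and logger = logging.getLogger(__name__).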
def upgrade_state_dict_named(self, state_dict, name):
    super().upgrade_state_dict_named(state_dict, name)

    prefix = name + '.' if name != '' else ''
    current_head_names = (
        [] if not hasattr(self, 'classification_heads')
        else self.classification_heads.keys()
    )
    # Handle new classification heads present in the state dict.
    keys_to_delete = []
    for k in state_dict.keys():
        if not k.startswith(prefix + 'classification_heads.'):
            continue

        head_name = k[len(prefix + 'classification_heads.'):].split('.')[0]
        num_classes = state_dict[
            prefix + 'classification_heads.' + head_name + '.out_proj.weight'
        ].size(0)
        inner_dim = state_dict[
            prefix + 'classification_heads.' + head_name + '.dense.weight'
        ].size(0)
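        # With --load-checkpoint-heads, adopt heads from the checkpoint that
        # the model doesn't have yet; otherwise, drop checkpoint heads that
        # are missing from the model or whose dimensions disagree with it.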
        if getattr(self.args, 'load_checkpoint_heads', False):
            if head_name not in current_head_names:
                self.register_classification_head(head_name, num_classes, inner_dim)
        else:
            if head_name not in current_head_names:
                logger.warning(
                    'deleting classification head ({}) from checkpoint '
                    'not present in current model: {}'.format(head_name, k)
                )
                keys_to_delete.append(k)
            elif (
                num_classes != self.classification_heads[head_name].out_proj.out_features
                or inner_dim != self.classification_heads[head_name].dense.out_features
            ):
                logger.warning(
                    'deleting classification head ({}) from checkpoint '
                    'with different dimensions than current model: {}'.format(head_name, k)
                )
                keys_to_delete.append(k)
    for k in keys_to_delete:
        del state_dict[k]
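    # Helper: drop the last row of an embedding/projection matrix in the
    # checkpoint (used below to strip the trailing <mask> embedding).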
    def truncate_emb(key):
        if key in state_dict:
            state_dict[key] = state_dict[key][:-1, :]
    # When finetuning on a translation task, remove the last row of the
    # embedding matrix, which corresponds to the mask_idx token: <mask> is
    # appended as the final dictionary symbol during denoising pretraining,
    # but translation dictionaries do not contain it.
    loaded_dict_size = state_dict['encoder.embed_tokens.weight'].size(0)
    if (
        loaded_dict_size == len(self.encoder.dictionary) + 1
        and '<mask>' not in self.encoder.dictionary
    ):
        truncate_emb('encoder.embed_tokens.weight')
        truncate_emb('decoder.embed_tokens.weight')
        truncate_emb('encoder.output_projection.weight')
        truncate_emb('decoder.output_projection.weight')
    # When continuing pretraining on a new set of languages for mBART,
    # add extra lang embeddings at the end of embed_tokens.
    # Note: newly added languages are assumed to have been added at the end.
    if (
        self.args.task == 'multilingual_denoising'
        and loaded_dict_size < len(self.encoder.dictionary)
    ):
        logger.info(
            'Adding extra language embeddings not found in pretrained model '
            'for continued pretraining of MBART on new set of languages.'
        )
        loaded_mask_token_embedding = state_dict['encoder.embed_tokens.weight'][-1, :]

        num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
        embed_dim = state_dict['encoder.embed_tokens.weight'].size(1)

        new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
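        # Match fairseq's default embedding initialization: N(0, embed_dim ** -0.5).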
        nn.init.normal_(
            new_lang_embed_to_add,
            mean=0,
            std=embed_dim ** -0.5,
        )
        new_lang_embed_to_add = new_lang_embed_to_add.to(
            dtype=state_dict['encoder.embed_tokens.weight'].dtype,
        )
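        # Splice the new language embeddings in ahead of the <mask> row so
        # that <mask> stays last: [old rows w/o <mask>, new lang ids, <mask>].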
        state_dict['encoder.embed_tokens.weight'] = torch.cat([
            state_dict['encoder.embed_tokens.weight'][:loaded_dict_size - 1, :],
            new_lang_embed_to_add,
            loaded_mask_token_embedding.unsqueeze(0),
        ])
        state_dict['decoder.embed_tokens.weight'] = torch.cat([
            state_dict['decoder.embed_tokens.weight'][:loaded_dict_size - 1, :],
            new_lang_embed_to_add,
            loaded_mask_token_embedding.unsqueeze(0),
        ])
    # Copy any newly-added classification heads into the state dict
    # with their current weights.
    if hasattr(self, 'classification_heads'):
        cur_state = self.classification_heads.state_dict()
        for k, v in cur_state.items():
            if prefix + 'classification_heads.' + k not in state_dict:
                # logger.info takes %-style format args, not a comma-separated
                # message; concatenate instead of passing a stray second arg.
                logger.info('Overwriting ' + prefix + 'classification_heads.' + k)
                state_dict[prefix + 'classification_heads.' + k] = v
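# Note: fairseq invokes this hook from BaseFairseqModel.load_state_dict()
# (via upgrade_state_dict), so the fixups above are applied to the checkpoint
# before the weights are actually loaded into the model.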