in fairseq/models/roberta/model.py [0:0]
def upgrade_state_dict_named(self, state_dict, name):
    """Upgrade a checkpoint's state dict in place so it loads into this model.

    Steps, in order:

    1. Rename keys starting with ``<prefix>decoder`` to ``<prefix>encoder``.
    2. Rename ``.emb_layer_norm.`` to ``.layernorm_embedding.`` in any key.
    3. Delegate to ``super().upgrade_state_dict_named`` for child modules.
    4. Reconcile checkpoint classification heads with this model's heads.
       If ``self.args.load_checkpoint_heads`` is truthy, heads missing from
       the model are registered using the checkpoint's ``out_proj``/``dense``
       output sizes. Otherwise, checkpoint head keys are deleted (with a
       warning) when the head is absent from the model or its sizes differ.
    5. Copy the model's current classification-head parameters into every
       head key the state dict does not already contain.
    6. Adapt data2vec checkpoints. If ``<prefix>encoder._ema`` is present but
       ``<prefix>encoder.lm_head.weight`` is not, fill the LM head from the
       model's current weights. Then always drop ``<prefix>encoder._ema`` and
       any ``<prefix>encoder.regression_head*`` keys.

    Args:
        state_dict: mapping of parameter names to tensors; mutated in place.
        name: dotted path of this model inside ``state_dict`` ("" for the
            root). Every key this method handles is relative to
            ``name + "."``.
    """
    prefix = name + "." if name != "" else ""

    # Rename decoder -> encoder before upgrading children modules.
    old_root = prefix + "decoder"
    new_root = prefix + "encoder"
    for k in list(state_dict.keys()):
        if k.startswith(old_root):
            state_dict[new_root + k[len(old_root):]] = state_dict.pop(k)

    # Rename emb_layer_norm -> layernorm_embedding.
    for k in list(state_dict.keys()):
        if ".emb_layer_norm." in k:
            new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.")
            state_dict[new_k] = state_dict.pop(k)

    # Upgrade children modules.
    super().upgrade_state_dict_named(state_dict, name)

    # Handle new classification heads present in the state dict.
    head_prefix = prefix + "classification_heads."
    # Deliberately a live keys() view (not a copy): a head registered below
    # becomes visible immediately, so the remaining keys of that same head
    # do not register it again.
    current_head_names = (
        self.classification_heads.keys()
        if hasattr(self, "classification_heads")
        else []
    )
    keys_to_delete = []
    for k in state_dict.keys():
        if not k.startswith(head_prefix):
            continue
        head_name = k[len(head_prefix):].split(".")[0]
        num_classes = state_dict[
            head_prefix + head_name + ".out_proj.weight"
        ].size(0)
        inner_dim = state_dict[head_prefix + head_name + ".dense.weight"].size(0)

        if getattr(self.args, "load_checkpoint_heads", False):
            if head_name not in current_head_names:
                self.register_classification_head(head_name, num_classes, inner_dim)
        elif head_name not in current_head_names:
            logger.warning(
                "deleting classification head ({}) from checkpoint "
                "not present in current model: {}".format(head_name, k)
            )
            keys_to_delete.append(k)
        else:
            head = self.classification_heads[head_name]
            if (
                num_classes != head.out_proj.out_features
                or inner_dim != head.dense.out_features
            ):
                logger.warning(
                    "deleting classification head ({}) from checkpoint "
                    "with different dimensions than current model: {}".format(
                        head_name, k
                    )
                )
                keys_to_delete.append(k)
    # Deleting is deferred so the dict is not mutated while being iterated.
    for k in keys_to_delete:
        del state_dict[k]

    # Copy any newly-added classification heads into the state dict
    # with their current weights. Note that the log line says "Overwriting",
    # but only keys that are absent from the state dict are written.
    if hasattr(self, "classification_heads"):
        cur_state = self.classification_heads.state_dict()
        for k, v in cur_state.items():
            full_key = head_prefix + k
            if full_key not in state_dict:
                logger.info("Overwriting " + full_key)
                state_dict[full_key] = v

    # Adapt data2vec models. Keys are resolved relative to this model's
    # prefix, so the adaptation also works when the model is nested.
    ema_key = prefix + "encoder._ema"
    lm_head_prefix = prefix + "encoder.lm_head."
    if ema_key in state_dict and lm_head_prefix + "weight" not in state_dict:
        lm_state = self.encoder.lm_head.state_dict()
        for k, v in lm_state.items():
            state_dict[lm_head_prefix + k] = v
    # Always strip data2vec-only entries, which this model has no parameters
    # for, even when the checkpoint already shipped an LM head.
    regression_prefix = prefix + "encoder.regression_head"
    for k in list(state_dict.keys()):
        if k.startswith(regression_prefix) or k == ema_key:
            del state_dict[k]