in optimum/neuron/models/training/modeling_utils.py
from torch import nn


def _fix_state_dict_key_on_load(key) -> tuple[str, bool]:
    # Rename LayerNorm beta & gamma params for some early models ported from TensorFlow (e.g. Bert).
    # This rename is logged.
    if key.endswith("LayerNorm.beta"):
        return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
    if key.endswith("LayerNorm.gamma"):
        return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True

    # Rename weight norm parametrizations to match changes across torch versions.
    # This impacts a number of speech models, e.g. Hubert, Wav2Vec2, and others.
    # This rename is not logged.
    if hasattr(nn.utils.parametrizations, "weight_norm"):
        if key.endswith("weight_g"):
            return key.replace("weight_g", "parametrizations.weight.original0"), True
        # ** Difference from the original _fix_state_dict_key_on_load **
        # Without the check that "qkv_proj" is not in the key, this method would also rename the
        # value weights when GQAQKVColumnParallelLinear is used.
        if key.endswith("weight_v") and "qkv_proj" not in key:
            return key.replace("weight_v", "parametrizations.weight.original1"), True
    else:
        if key.endswith("parametrizations.weight.original0"):
            return key.replace("parametrizations.weight.original0", "weight_g"), True
        if key.endswith("parametrizations.weight.original1"):
            return key.replace("parametrizations.weight.original1", "weight_v"), True

    return key, False
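
To make the remapping concrete, here is a minimal sketch (not part of the module) that runs a few representative checkpoint keys through the function. The key names are hypothetical; the `qkv_proj` one mimics a fused GQAQKVColumnParallelLinear weight, and the `weight_g`/`weight_v` outputs assume a torch build where `nn.utils.parametrizations.weight_norm` exists.

# Hypothetical keys for illustration only.
sample_keys = [
    "bert.encoder.layer.0.output.LayerNorm.beta",    # -> ...LayerNorm.bias
    "bert.encoder.layer.0.output.LayerNorm.gamma",   # -> ...LayerNorm.weight
    "encoder.pos_conv_embed.conv.weight_g",          # -> ...parametrizations.weight.original0
    "encoder.pos_conv_embed.conv.weight_v",          # -> ...parametrizations.weight.original1
    "model.layers.0.self_attn.qkv_proj.weight_v",    # left untouched on purpose
]

for key in sample_keys:
    new_key, renamed = _fix_state_dict_key_on_load(key)
    print(f"{key} -> {new_key} (renamed={renamed})")

The last key exercises the `qkv_proj` guard: without it, the fused value weight would be renamed and would no longer match the parameter name GQAQKVColumnParallelLinear expects.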