in src/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py [0:0]
def _rename_original_encoder_key(module_name):
    """Return the Hugging Face name for an original Mega encoder state-dict key, or None to keep the name.

    Names containing "beta" or "gamma" aren't safe to use in the Hugging Face ecosystem (they are renamed upon
    `_load_pretrained`), so they are rewritten here. The EMA sub-layer, always called "move" in the original
    repository, is renamed to "ema_gate" for readability, together with its confusingly-named parameters.
    """
    # beta is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias, and must be renamed
    # due to flax/tf weights; the replacement depends on which module owns the parameter
    if "beta" in module_name:
        if "move.beta" in module_name:
            return module_name.replace("move.beta", "ema_gate.ema_expansion_matrix")
        if "mega_layer.beta" in module_name:
            return module_name.replace("beta", "qk_bias")
        return module_name.replace("beta", "b_param")
    # gamma is used in EMA and MovingAverageGatedAttention, and must be renamed due to flax/tf weights
    if "gamma" in module_name:
        if "move.gamma" in module_name:
            return module_name.replace("move.gamma", "ema_gate.kernel_projection_matrix")
        if "mega_layer.gamma" in module_name:
            return module_name.replace("gamma", "qk_weight")
        return module_name.replace("gamma", "g_param")
    # alpha is used in EMA and positional bias; only the EMA one is renamed, to improve readability
    if "move.alpha" in module_name:
        return module_name.replace("move.alpha", "ema_gate.decay_factor")
    # delta is only used in EMA; renaming to improve readability
    if "move.delta" in module_name:
        return module_name.replace("move.delta", "ema_gate.damping_factor")
    # omega is only used in EMA; renaming to improve readability. Test for the full "move.omega" prefix (as the
    # other EMA branches do) so unrelated keys are not recorded as no-op renames.
    if "move.omega" in module_name:
        return module_name.replace("move.omega", "ema_gate.residual_weight")
    return None


def convert_checkpoint_to_huggingface(pretrained_checkpoint_path, output_path, includes_tokenizer):
    """Convert an original Mega masked-LM checkpoint into a Hugging Face `MegaForMaskedLM` checkpoint.

    The original model is rebuilt from `model_args.pkl`, `encoder_weights.pt` and `mlm_head_weights.pt` found in
    `pretrained_checkpoint_path`. Its weights are copied into a freshly configured `MegaForMaskedLM`, and both
    models are run on a random masked batch. The HF model is saved only if the outputs match.

    Args:
        pretrained_checkpoint_path: directory holding the original checkpoint files (and the tokenizer files when
            `includes_tokenizer` is true).
        output_path: directory passed to `save_pretrained` for the converted model (and tokenizer).
        includes_tokenizer: if true, load a tokenizer from `pretrained_checkpoint_path` with `AutoTokenizer` and
            save it to `output_path`.

    Raises:
        RuntimeError: if the original and converted outputs differ by more than `atol=1e-3`.
    """
    # NOTE: unpickling can execute arbitrary code -- only convert checkpoints from a trusted source.
    with open(os.path.join(pretrained_checkpoint_path, "model_args.pkl"), "rb") as f:
        mega_original_args = pkl.load(f)
    mega_args = mega_original_args["mega_args"]

    # load the original encoder and its weights
    original_mlm = OriginalMegaForMaskedLM(**mega_original_args).eval()
    print(
        "Original Mega encoder:",
        original_mlm.mega.load_state_dict(
            torch.load(
                os.path.join(pretrained_checkpoint_path, "encoder_weights.pt"), map_location="cpu", weights_only=True
            )
        ),
    )
    # the MLM head weights are loaded into both models, so read them from disk only once
    mlm_head_state_dict = torch.load(
        os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True
    )
    print("Original Mega MLM layer:", original_mlm.mlm_head.load_state_dict(mlm_head_state_dict))

    # create a new config from the old one
    hf_config = MegaConfig(
        num_hidden_layers=mega_original_args["depth"],
        vocab_size=mega_original_args["vocab_size"],
        hidden_size=mega_args.encoder_embed_dim,
        shared_representation_size=mega_args.encoder_z_dim,
        intermediate_size=mega_args.encoder_hidden_dim,
        ema_projection_size=mega_args.encoder_n_dim,
        dropout_prob=mega_args.dropout,
        attention_probs_dropout_prob=mega_args.attention_dropout,
        hidden_dropout_prob=mega_args.hidden_dropout,
        activation=mega_args.activation_fn,
        attention_activation=mega_args.attention_activation_fn,
        bidirectional=mega_args.bidirectional,
        use_chunking=mega_args.encoder_chunk_size > 0,
        chunk_size=mega_args.encoder_chunk_size,
        truncation=mega_args.truncation_length,
        normalization_type=mega_args.normalization_type,
        normalize_before_mega=True,
        norm_affine=True,
        use_feature_dropout=mega_args.feature_dropout,
        relative_positional_bias=mega_args.rel_pos_bias,
        max_positions=mega_args.max_source_positions,
        nffn_hidden_size=mega_args.encoder_ffn_embed_dim,
        normalize_before_ffn=mega_args.normalize_before,
        # new arguments added for HF implementation
        nffn_activation_dropout_prob=0.0,
        add_token_type_embeddings=False,
        add_lm_hidden_dense_layer=False,
    )
    hf_mlm = MegaForMaskedLM(hf_config).eval()

    # the original checkpoint just uses nn.Embedding for the word embeddings;
    # we use a wrapper module for embeddings to add support for positional embeddings
    hf_mlm.mega.embedding_layer.word_embeddings.weight = original_mlm.mega.embedding_layer.weight

    # rename entries of the original encoder state dict to the names used by the HF implementation
    original_state_dict = original_mlm.mega.encoders.state_dict()
    updated_keys = {}
    for module_name in original_state_dict:
        new_module_name = _rename_original_encoder_key(module_name)
        if new_module_name:
            updated_keys[module_name] = new_module_name
    if updated_keys:
        print(f"Renaming these keys: {updated_keys.keys()}")
    else:
        print("No need to rename state dict entries")
    # renames are applied after the scan above so the dict is not mutated while being iterated
    for old, new in updated_keys.items():
        original_state_dict[new] = original_state_dict.pop(old)

    # now attempt to load the state dictionary with updated names
    # note that we now call it `mega.layers` instead of `mega.encoders` due to hugging face style
    print("HF Mega encoder:", hf_mlm.mega.layers.load_state_dict(original_state_dict))
    # load the MLM head weights directly
    print("HF Mega MLM layer:", hf_mlm.mlm_head.load_state_dict(mlm_head_state_dict))

    # test on a randomly generated input sequence
    input_ids = torch.randint(0, hf_config.vocab_size, size=(4, 256))
    input_mask = torch.ones_like(input_ids)
    # mask a few tokens to make sure masking is applied appropriately :)
    input_mask[:, -10:] = 0
    # run forward passes; gradients are never needed for this comparison
    with torch.no_grad():
        original_output = original_mlm(input_ids, input_mask, batch_first=True, ignore_mask_value=0)
        hf_output = hf_mlm(input_ids, input_mask)[0]

    # print shapes and diff
    print(f"original output {original_output.shape}")
    print(f"hf output {hf_output.shape}")
    # use the absolute difference: a signed max would hide large negative deviations
    print(f"max diff: {(original_output - hf_output).abs().max()}")  # expected ~0.0
    success = torch.allclose(original_output, hf_output, atol=1e-3)

    if success:
        print("Yay!")
        hf_mlm.save_pretrained(output_path)
    else:
        raise RuntimeError(f"Something's broken :(\nOriginal:\n{original_output}\n\nHF\n{hf_output}\n{hf_mlm}")

    if includes_tokenizer:
        print("Transferring tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint_path)
        tokenizer.save_pretrained(output_path)