def convert_checkpoint_to_huggingface()

in src/transformers/models/deprecated/mega/convert_mega_original_pytorch_checkpoint_to_pytorch.py


def convert_checkpoint_to_huggingface(pretrained_checkpoint_path, output_path, includes_tokenizer):
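    # the checkpoint directory stores the original model's construction arguments as a pickle;
    # these are reused below both to rebuild the original model and to populate the MegaConfig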
    with open(os.path.join(pretrained_checkpoint_path, "model_args.pkl"), "rb") as f:
        mega_original_args = pkl.load(f)

    # load the original encoder
    original_mlm = OriginalMegaForMaskedLM(**mega_original_args).eval()

    # load its weights
    print(
        "Original Mega encoder:",
        original_mlm.mega.load_state_dict(
            torch.load(
                os.path.join(pretrained_checkpoint_path, "encoder_weights.pt"), map_location="cpu", weights_only=True
            )
        ),
    )
    print(
        "Original Mega MLM layer:",
        original_mlm.mlm_head.load_state_dict(
            torch.load(
                os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True
            )
        ),
    )

    # create a new config from the old one
    hf_config = MegaConfig(
        num_hidden_layers=mega_original_args["depth"],
        vocab_size=mega_original_args["vocab_size"],
        hidden_size=mega_original_args["mega_args"].encoder_embed_dim,
        shared_representation_size=mega_original_args["mega_args"].encoder_z_dim,
        intermediate_size=mega_original_args["mega_args"].encoder_hidden_dim,
        ema_projection_size=mega_original_args["mega_args"].encoder_n_dim,
        dropout_prob=mega_original_args["mega_args"].dropout,
        attention_probs_dropout_prob=mega_original_args["mega_args"].attention_dropout,
        hidden_dropout_prob=mega_original_args["mega_args"].hidden_dropout,
        activation=mega_original_args["mega_args"].activation_fn,
        attention_activation=mega_original_args["mega_args"].attention_activation_fn,
        bidirectional=mega_original_args["mega_args"].bidirectional,
        use_chunking=mega_original_args["mega_args"].encoder_chunk_size > 0,
        chunk_size=mega_original_args["mega_args"].encoder_chunk_size,
        truncation=mega_original_args["mega_args"].truncation_length,
        normalization_type=mega_original_args["mega_args"].normalization_type,
        normalize_before_mega=True,
        norm_affine=True,
        use_feature_dropout=mega_original_args["mega_args"].feature_dropout,
        relative_positional_bias=mega_original_args["mega_args"].rel_pos_bias,
        max_positions=mega_original_args["mega_args"].max_source_positions,
        nffn_hidden_size=mega_original_args["mega_args"].encoder_ffn_embed_dim,
        normalize_before_ffn=mega_original_args["mega_args"].normalize_before,
        # new arguments added for HF implementation
        nffn_activation_dropout_prob=0.0,
        add_token_type_embeddings=False,
        add_lm_hidden_dense_layer=False,
    )

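    # instantiate the Hugging Face model; .eval() disables dropout so the numerical comparison below is deterministic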
    hf_mlm = MegaForMaskedLM(hf_config).eval()

    # the original checkpoint just uses nn.Embedding for the word embeddings
    # we use a wrapper module for embeddings to add support for positional embeddings
    hf_mlm.mega.embedding_layer.word_embeddings.weight = original_mlm.mega.embedding_layer.weight

    # modify the state dictionary of the original checkpoint to account for naming issues in the Hugging Face
    # ecosystem -- any parameter names containing "beta" or "gamma" aren't safe to use and get renamed when a
    # pretrained checkpoint is loaded; we also rename parameters whose original names were confusing
    original_state_dict = original_mlm.mega.encoders.state_dict()
    updated_keys = {}
    for module_name in original_state_dict.keys():
        new_module_name = None
        # have to handle gamma, beta, and alpha differently due to their use
        # in multiple modules within the original repository;
        # beta is used in EMA, MovingAverageGatedAttention, and RotaryRelativePositionalBias, and must be renamed due to flax/tf weights
        # the EMA sublayer was renamed from "move" to "ema_gate" for readability, so that is also done here
        if "beta" in module_name:
            # EMA sub-layers were always called "move" in the original repo
            if "move.beta" in module_name:
                new_module_name = module_name.replace("move.beta", "ema_gate.ema_expansion_matrix")
            elif "mega_layer.beta" in module_name:
                new_module_name = module_name.replace("beta", "qk_bias")
            else:
                new_module_name = module_name.replace("beta", "b_param")
        # gamma is used in EMA and MovingAverageGatedAttention, and must be renamed due to flax/tf weights
        elif "gamma" in module_name:
            if "move.gamma" in module_name:
                new_module_name = module_name.replace("move.gamma", "ema_gate.kernel_projection_matrix")
            elif "mega_layer.gamma" in module_name:
                new_module_name = module_name.replace("gamma", "qk_weight")
            else:
                new_module_name = module_name.replace("gamma", "g_param")
        # alpha is used in EMA and positional bias; renaming to improve readability
        elif "move.alpha" in module_name:
            new_module_name = module_name.replace("move.alpha", "ema_gate.decay_factor")
        # delta is only used in EMA; renaming to improve readability
        elif "move.delta" in module_name:
            new_module_name = module_name.replace("move.delta", "ema_gate.damping_factor")
        # omega is only used in EMA; renaming to improve readability
        elif "omega" in module_name:
            new_module_name = module_name.replace("move.omega", "ema_gate.residual_weight")

        if new_module_name:
            updated_keys[module_name] = new_module_name

    if updated_keys:
        print(f"Renaming these keys: {list(updated_keys.keys())}")
    else:
        print("No need to rename state dict entries")
    for old, new in updated_keys.items():
        original_state_dict[new] = original_state_dict.pop(old)

    # now attempt to load the state dictionary with updated names
    # note that we now call it `mega.layers` instead of `mega.encoders` to match Hugging Face naming conventions
    print("HF Mega encoder:", hf_mlm.mega.layers.load_state_dict(original_state_dict))

    # load the MLM head weights directly
    print(
        "HF Mega MLM layer:",
        hf_mlm.mlm_head.load_state_dict(
            torch.load(
                os.path.join(pretrained_checkpoint_path, "mlm_head_weights.pt"), map_location="cpu", weights_only=True
            )
        ),
    )

    # test on a randomly generated input sequence
    input_ids = torch.randint(0, hf_config.vocab_size, size=(4, 256))
    input_mask = torch.ones_like(input_ids)
    # mask a few tokens to make sure masking is applied appropriately :)
    input_mask[:, -10:] = 0

    # run forward passes
    original_output = original_mlm(input_ids, input_mask, batch_first=True, ignore_mask_value=0)
    hf_output = hf_mlm(input_ids, input_mask)[0]

    # print shapes and diff
    print(f"original output {original_output.shape}")
    print(f"hf output {hf_output.shape}")
    print(f"max diff: {(original_output - hf_output).max()}")  # 0.0
    success = torch.allclose(original_output, hf_output, atol=1e-3)

    if success:
        print("Yay!")
        hf_mlm.save_pretrained(output_path)
    else:
        raise RuntimeError(f"Something's broken :(\nOriginal:\n{original_output}\n\nHF\n{hf_output}\n{hf_mlm}")

    if includes_tokenizer:
        print("Transferring tokenizer")
        tokenizer = AutoTokenizer.from_pretrained(pretrained_checkpoint_path)
        tokenizer.save_pretrained(output_path)
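
In the full conversion script this function is driven from a command-line entry point. A minimal sketch of such an entry point is shown below, assuming the flag names simply mirror the function's parameters; the actual argparse block in the file may differ.

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # directory containing model_args.pkl, encoder_weights.pt, and mlm_head_weights.pt
    parser.add_argument("--pretrained_checkpoint_path", type=str, required=True)
    # directory to write the converted Hugging Face model (and tokenizer, if transferred) to
    parser.add_argument("--output_path", type=str, required=True)
    # set this flag if the checkpoint directory also contains a saved tokenizer
    parser.add_argument("--includes_tokenizer", action="store_true")
    args = parser.parse_args()

    convert_checkpoint_to_huggingface(args.pretrained_checkpoint_path, args.output_path, args.includes_tokenizer)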