def check_hf_mg_forward()

in toolkits/model_checkpoints_convertor/llama/hf2mcore_llama3_1.py [0:0]


# Module-level imports used by this function; note that the exact import path
# for get_ltor_masks_and_position_ids varies across Megatron releases (older
# versions expose it as megatron.utils).
from functools import partial

import torch
from megatron.training.utils import get_ltor_masks_and_position_ids


def check_hf_mg_forward(hfmodel, mgmodel, mgargs):
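    """Run a toy forward pass through both the HuggingFace model and the
    converted Megatron-Core model, capture per-layer activations via forward
    hooks, and print element-wise differences to sanity-check the conversion."""
    # Cast both models to the same dtype so activations are directly comparable.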
    if mgargs.fp16:
        mgmodel = mgmodel.half()
        hfmodel = hfmodel.half()
    elif mgargs.bf16:
        mgmodel = mgmodel.bfloat16()
        hfmodel = hfmodel.bfloat16()
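    # One dict per decoder layer; the hooks below fill them with named activations.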
    hf_hiddens = [{} for _ in range(mgargs.num_layers)]
    mg_hiddens = [{} for _ in range(mgargs.num_layers)]

    hidden_size = mgargs.hidden_size
    vocab_size = mgargs.padded_vocab_size

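    # Pre-forward hook: record a module's input under the probe name encoded in
    # `mode` ("<frame>-<name>"). Inputs are transposed/sliced so the HuggingFace
    # ([batch, seq, hidden]) and Megatron ([seq, batch, hidden]) tensors line up.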
    def print_input_hook(module, args, kwargs, layer_idx, mode):
        frame, name = mode.split("-")
        if frame == "hf" and "attn_in" in mode:
            hf_hiddens[layer_idx][name] = kwargs.get("hidden_states")[0]
        elif frame == "hf":
            hf_hiddens[layer_idx][name] = args[0].transpose(0, 1)
        elif frame == "mg" and "layer" in mode:
            mg_hiddens[layer_idx][name] = kwargs.get("hidden_states")
        elif frame == "mg" and mode == "mg-attn_in":
            mg_hiddens[layer_idx][name] = args[0][:, 0]
        elif frame == "mg":
            mg_hiddens[layer_idx][name] = args[0]

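    # Post-forward hook: record a module's output, flattened for comparison; for
    # the LM heads it also keeps the projection weight and the argmax token ids.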
    def print_output_hook(module, args, kwargs, output, layer_idx, mode):
        frame, name = mode.split("-")
        if mode in ["hf-lmhead"]:
            hf_hiddens[layer_idx][name] = output.transpose(0, 1).reshape(-1, vocab_size)
            hf_hiddens[layer_idx][name + "_weight"] = module.weight
            hf_hiddens[layer_idx][name + "_token"] = output.transpose(0, 1).max(dim=-1)[
                1
            ]
        elif mode in ["mg-lmhead"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, vocab_size)
            mg_hiddens[layer_idx][name + "_weight"] = module.weight
            mg_hiddens[layer_idx][name + "_token"] = output[0].max(dim=-1)[1]
        elif mode in ["hf-o_proj_out"]:
            hf_hiddens[layer_idx][name] = output
            hf_hiddens[layer_idx][name + "_weight"] = module.weight
        elif mode in ["mg-o_proj_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
            mg_hiddens[layer_idx][name + "_weight"] = module.weight
        elif mode in ["hf-attn_out"]:
            hf_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["hf-core_attn_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["mg-core_attn_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["mg-attn_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["mg-mlp_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["hf-mlp_out"]:
            hf_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)

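    # Hook the two LM heads so the final hidden-to-vocab projections can be compared.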
    hfmodel.lm_head.register_forward_hook(
        partial(print_output_hook, layer_idx=mgargs.num_layers - 1, mode="hf-lmhead"),
        with_kwargs=True,
    )

    mgmodel.output_layer.register_forward_hook(
        partial(print_output_hook, layer_idx=mgargs.num_layers - 1, mode="mg-lmhead"),
        with_kwargs=True,
    )

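    # Register hooks on every HuggingFace decoder layer and its attention/MLP sub-modules.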
    for idx, layer in enumerate(hfmodel.model.layers):

        layer.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-layer_in"),
            with_kwargs=True,
        )

        layer.self_attn.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-attn_in"),
            with_kwargs=True,
        )

        layer.self_attn.o_proj.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-o_proj_in"),
            with_kwargs=True,
        )

        layer.self_attn.o_proj.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="hf-o_proj_out"),
            with_kwargs=True,
        )

        layer.self_attn.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="hf-attn_out"),
            with_kwargs=True,
        )

        layer.post_attention_layernorm.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-pre_ln_in"),
            with_kwargs=True,
        )

        layer.mlp.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-mlp_in"),
            with_kwargs=True,
        )

        layer.mlp.down_proj.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-down_in"),
            with_kwargs=True,
        )

        layer.mlp.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="hf-mlp_out"),
            with_kwargs=True,
        )

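    # Register the corresponding hooks on the Megatron-Core decoder layers.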
    for idx, layer in enumerate(mgmodel.decoder.layers):

        layer.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-layer_in"),
            with_kwargs=True,
        )

        layer.self_attention.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-attn_in"),
            with_kwargs=True,
        )

        layer.self_attention.linear_proj.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-o_proj_in"),
            with_kwargs=True,
        )

        layer.self_attention.linear_proj.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="mg-o_proj_out"),
            with_kwargs=True,
        )

        layer.self_attention.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="mg-attn_out"),
            with_kwargs=True,
        )

        layer.pre_mlp_layernorm.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-pre_ln_in"),
            with_kwargs=True,
        )

        layer.mlp.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-mlp_in"),
            with_kwargs=True,
        )

        layer.mlp.linear_fc2.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-down_in"),
            with_kwargs=True,
        )

        layer.mlp.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="mg-mlp_out"),
            with_kwargs=True,
        )

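    # Run a tiny dummy sequence through both models; the causal attention mask,
    # loss mask and position ids come from Megatron's utility (-100 is passed as
    # the EOD token id and all reset/eod-mask options are enabled).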
    input_ids = torch.tensor([[1, 2, 3]]).long().cuda()
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        input_ids, -100, True, True, True
    )
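    # Print both module trees for a quick visual sanity check.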
    print(hfmodel)
    print(mgmodel)
    is_oom = False
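    # HuggingFace forward pass; the hooks above populate hf_hiddens.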
    with torch.inference_mode():
        try:
            hfmodel.cuda()
            hflogits = hfmodel(input_ids=input_ids).logits
        except torch.cuda.OutOfMemoryError:
            print("oom for huggingface model forward")
            is_oom = True
        hfmodel.cpu()
        del hfmodel

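    # Megatron forward pass, run after the HF model has been moved off the GPU
    # to limit peak memory.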
    with torch.inference_mode():
        try:
            mgmodel.cuda()
            mglogits = mgmodel(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        except torch.cuda.OutOfMemoryError:
            print("oom for megatron model forward")
            is_oom = True
        del mgmodel

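    # Compare captured activations layer by layer: count exact mismatches, count
    # elements differing by more than epsilon, and report the max absolute diff.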
    epsilon = 1e-5
    for idx, (hfh, mgh) in enumerate(zip(hf_hiddens, mg_hiddens)):
        assert len(hfh) == len(mgh)
        for k, hfv in hfh.items():
            mgv, hfv = mgh[k].cpu(), hfv.cpu()
            mismatch_num = (hfv != mgv).sum()
            diff_num = ((hfv - mgv).abs() > epsilon).sum()
            diff_max = (hfv - mgv).abs().max()
            print(
                f"layer:{idx}, {k}, mismatch:{mismatch_num}, diff>{epsilon}:[{diff_num}/{hfv.numel()}] diff_max:{diff_max}"
            )

    if not is_oom:
        mismatch_num = (hflogits != mglogits).sum()
        diff_num = ((hflogits - mglogits).abs() > epsilon).sum()
        diff_max = (hflogits - mglogits).abs().max()
        print(
            f"logits: mismatch:{mismatch_num}, diff>{epsilon}:[{diff_num}/{hflogits.numel()}] diff_max:{diff_max}"
        )