in toolkits/model_checkpoints_convertor/llama/hf2mcore_llama3_1.py [0:0]
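# This excerpt relies on imports from earlier in the file (not shown here); roughly:
#   import torch
#   from functools import partial
#   from megatron.training.utils import get_ltor_masks_and_position_ids  # module path may differ across Megatron-LM versions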
def check_hf_mg_forward(hfmodel, mgmodel, mgargs):
    """Run a tiny forward pass through both the HuggingFace and the Megatron-Core
    model and print per-layer activation differences, to sanity-check a checkpoint
    conversion."""
    # Cast both models to the same precision so activations are comparable.
    if mgargs.fp16:
        mgmodel = mgmodel.half()
        hfmodel = hfmodel.half()
    elif mgargs.bf16:
        mgmodel = mgmodel.bfloat16()
        hfmodel = hfmodel.bfloat16()

    # Per-layer scratch dicts filled by the hooks below.
    hf_hiddens = [{} for _ in range(mgargs.num_layers)]
    mg_hiddens = [{} for _ in range(mgargs.num_layers)]
    hidden_size = mgargs.hidden_size
    vocab_size = mgargs.padded_vocab_size
    def print_input_hook(module, args, kwargs, layer_idx, mode):
        # Pre-forward hook: record the tensor a module receives, keyed by the
        # "<frame>-<name>" mode string ("hf-..." or "mg-...").
        frame, name = mode.split("-")
        if frame == "hf" and "attn_in" in mode:
            hf_hiddens[layer_idx][name] = kwargs.get("hidden_states")[0]
        elif frame == "hf":
            hf_hiddens[layer_idx][name] = args[0].transpose(0, 1)
        elif frame == "mg" and "layer" in mode:
            mg_hiddens[layer_idx][name] = kwargs.get("hidden_states")
        elif frame == "mg" and mode == "mg-attn_in":
            mg_hiddens[layer_idx][name] = args[0][:, 0]
        elif frame == "mg":
            mg_hiddens[layer_idx][name] = args[0]
    def print_output_hook(module, args, kwargs, output, layer_idx, mode):
        # Post-forward hook: record the tensor a module produces.
        frame, name = mode.split("-")
        if mode in ["hf-lmhead"]:
            hf_hiddens[layer_idx][name] = output.transpose(0, 1).reshape(-1, vocab_size)
            hf_hiddens[layer_idx][name + "_weight"] = module.weight
            hf_hiddens[layer_idx][name + "_token"] = output.transpose(0, 1).max(dim=-1)[1]
        elif mode in ["mg-lmhead"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, vocab_size)
            mg_hiddens[layer_idx][name + "_weight"] = module.weight
            mg_hiddens[layer_idx][name + "_token"] = output[0].max(dim=-1)[1]
        elif mode in ["hf-o_proj_out"]:
            hf_hiddens[layer_idx][name] = output
            hf_hiddens[layer_idx][name + "_weight"] = module.weight
        elif mode in ["mg-o_proj_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
            mg_hiddens[layer_idx][name + "_weight"] = module.weight
        elif mode in ["hf-attn_out"]:
            hf_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["hf-core_attn_out"]:
            hf_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["mg-core_attn_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["mg-attn_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["mg-mlp_out"]:
            mg_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
        elif mode in ["hf-mlp_out"]:
            hf_hiddens[layer_idx][name] = output[0].reshape(-1, hidden_size)
    hfmodel.lm_head.register_forward_hook(
        partial(print_output_hook, layer_idx=mgargs.num_layers - 1, mode="hf-lmhead"),
        with_kwargs=True,
    )
    mgmodel.output_layer.register_forward_hook(
        partial(print_output_hook, layer_idx=mgargs.num_layers - 1, mode="mg-lmhead"),
        with_kwargs=True,
    )
    for idx, layer in enumerate(hfmodel.model.layers):
        layer.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-layer_in"),
            with_kwargs=True,
        )
        layer.self_attn.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-attn_in"),
            with_kwargs=True,
        )
        layer.self_attn.o_proj.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-o_proj_in"),
            with_kwargs=True,
        )
        layer.self_attn.o_proj.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="hf-o_proj_out"),
            with_kwargs=True,
        )
        layer.self_attn.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="hf-attn_out"),
            with_kwargs=True,
        )
        layer.post_attention_layernorm.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-pre_ln_in"),
            with_kwargs=True,
        )
        layer.mlp.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-mlp_in"),
            with_kwargs=True,
        )
        layer.mlp.down_proj.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="hf-down_in"),
            with_kwargs=True,
        )
        layer.mlp.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="hf-mlp_out"),
            with_kwargs=True,
        )
    for idx, layer in enumerate(mgmodel.decoder.layers):
        layer.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-layer_in"),
            with_kwargs=True,
        )
        layer.self_attention.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-attn_in"),
            with_kwargs=True,
        )
        layer.self_attention.linear_proj.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-o_proj_in"),
            with_kwargs=True,
        )
        layer.self_attention.linear_proj.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="mg-o_proj_out"),
            with_kwargs=True,
        )
        layer.self_attention.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="mg-attn_out"),
            with_kwargs=True,
        )
        layer.pre_mlp_layernorm.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-pre_ln_in"),
            with_kwargs=True,
        )
        layer.mlp.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-mlp_in"),
            with_kwargs=True,
        )
        layer.mlp.linear_fc2.register_forward_pre_hook(
            partial(print_input_hook, layer_idx=idx, mode="mg-down_in"),
            with_kwargs=True,
        )
        layer.mlp.register_forward_hook(
            partial(print_output_hook, layer_idx=idx, mode="mg-mlp_out"),
            with_kwargs=True,
        )
    # Tiny dummy batch; the goal is only to compare activations, not output quality.
    input_ids = torch.tensor([[1, 2, 3]]).long().cuda()
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        input_ids, -100, True, True, True
    )
    print(hfmodel)
    print(mgmodel)
    is_oom = False
    with torch.inference_mode():
        try:
            hfmodel.cuda()
            hflogits = hfmodel(input_ids=input_ids).logits
        except torch.cuda.OutOfMemoryError:
            print("oom for huggingface model forward")
            is_oom = True
        # Free GPU memory before running the Megatron model.
        hfmodel.cpu()
        del hfmodel

    with torch.inference_mode():
        try:
            mgmodel.cuda()
            mglogits = mgmodel(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        except torch.cuda.OutOfMemoryError:
            print("oom for megatron model forward")
            is_oom = True
        del mgmodel
    # Compare the activations captured by the hooks, layer by layer.
    epsilon = 1e-5
    for idx, (hfh, mgh) in enumerate(zip(hf_hiddens, mg_hiddens)):
        assert len(hfh) == len(mgh)
        for k, hfv in hfh.items():
            mgv, hfv = mgh[k].cpu(), hfv.cpu()
            mismatch_num = (hfv != mgv).sum()               # elements that are not bit-identical
            diff_num = ((hfv - mgv).abs() > epsilon).sum()  # elements differing by more than epsilon
            diff_max = (hfv - mgv).abs().max()
            print(
                f"layer:{idx}, {k}, diff: {mismatch_num}, diff>{epsilon}:[{diff_num}/{hfv.numel()}] diff_max:{diff_max}"
            )

    if not is_oom:
        mismatch_num = (hflogits != mglogits).sum()
        diff_num = ((hflogits - mglogits).abs() > epsilon).sum()
        diff_max = (hflogits - mglogits).abs().max()
        print(
            f"logits: {mismatch_num}, diff>{epsilon}:[{diff_num}/{hflogits.numel()}] diff_max:{diff_max}"
        )
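

# Illustrative usage sketch (not part of the original file): after loading both
# checkpoints during conversion, call this function to print per-layer activation
# differences. The loader names below are hypothetical placeholders.
#
#   hfmodel = AutoModelForCausalLM.from_pretrained(hf_ckpt_path)  # hypothetical variable
#   mgmodel = model_provider()                                    # Megatron model builder
#   check_hf_mg_forward(hfmodel, mgmodel, get_args())
#
# A healthy conversion shows diff_max close to zero in the early layers; a divergence
# that first appears at a specific hook (e.g. mg-o_proj_out) points at the
# corresponding weight mapping.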