def save_mgmodel()

in toolkits/model_checkpoints_convertor/deepseek/hf2mcore_deepseek_v3_moe.py


def save_mgmodel(mgmodel, args):
    # tp, etp, ep, pp
    args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size

    if args.num_experts is not None:
        args.expert_model_parallel_size = args.target_expert_model_parallel_size
        args.expert_tensor_parallel_size = args.target_expert_tensor_parallel_size

    os.makedirs(args.save, exist_ok=True)
    os.system("cp -rf " + args.load + "/*config.json " + args.save)
    os.system("cp -rf " + args.load + "/tokenizer* " + args.save)
    os.system("cp -rf " + args.load + "/*tok* " + args.save)

    tracker_filepath = os.path.join(args.save, 'latest_checkpointed_iteration.txt')
    with open(tracker_filepath, "w") as f:
        f.write("release")

    full_model = mgmodel.state_dict_for_save_checkpoint()
    for k in list(full_model.keys()):
        if 'extra_state' in k:
            # NOTE: since TE 1.14, fp8 metadata is saved as tensors.
            # Always drop these values in the MG ckpt to avoid potential issues.
            # This is fine because fp8 metadata is not supported by the HF ckpt.
            full_model[k] = None
        elif full_model[k] is None:
            full_model.pop(k)

    if args.num_experts is not None:
        if args.moe_grouped_gemm:
            pattern = r'weight(\d+)'
        else:
            pattern = r'local_experts\.(\d+)\.'
        assert args.num_experts % args.expert_model_parallel_size == 0
        num_local_experts = args.num_experts // args.expert_model_parallel_size

    if args.target_decoder_first_pipeline_num_layers is not None:
        assert args.pipeline_model_parallel_size > 1, "decoder_first_pipeline_num_layers is only valid when pp_size > 1"
        remained_layers = args.num_layers - args.target_decoder_first_pipeline_num_layers
        remained_stages = args.pipeline_model_parallel_size - 1
        assert remained_layers % remained_stages == 0
        pp_layers_per_stage = [args.target_decoder_first_pipeline_num_layers] + ([remained_layers // remained_stages] * remained_stages)
    else:
        pp_layers_per_stage = [args.num_layers // args.pipeline_model_parallel_size] * args.pipeline_model_parallel_size

    tp_size = args.tensor_model_parallel_size
    etp_size = args.expert_tensor_parallel_size
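    # Save one shard per (tp, etp, ep, pp) rank combination.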
    for (tp_rank, etp_rank, ep_rank, pp_rank) in generate_rank_group(
        args.tensor_model_parallel_size,
        args.expert_tensor_parallel_size,
        args.expert_model_parallel_size,
        args.pipeline_model_parallel_size
    ):
        model_split = {}
        layer_offset = sum(pp_layers_per_stage[:pp_rank])
        layers_to_copy = {}
        for layer in range(pp_layers_per_stage[pp_rank]):
            pp_layer_id = layer + layer_offset
            layers_to_copy[f"decoder.layers.{pp_layer_id}"] = layer
        checkpoint_name = get_checkpoint_name(
            args.save, 0, True, 
            args.pipeline_model_parallel_size > 1, 
            tp_rank, 
            pp_rank, 
            args.expert_model_parallel_size > 1, 
            ep_rank
        )
        print(f'save model to {checkpoint_name}')
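        # MTP (multi-token prediction) layers live only on the last pipeline stage.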
        has_mtp = pp_rank == args.pipeline_model_parallel_size - 1
        for k, v in full_model.items():
            # NOTE: skip keys whose layer does not belong to the current pp_rank
            if check_layer(layers_to_copy, k):
                layer_pattern = re.compile(r'\d+')
                res = layer_pattern.findall(k)
                k = re.sub(r"decoder\.layers\.\d+", "decoder.layers." + str(layers_to_copy["decoder.layers." + res[0]]), k)
            elif 'mtp' in k:
                if not has_mtp:
                    continue
            elif not contains(k, ["word_embeddings", "output_layer", "final_layernorm"]):
                continue

            if not isinstance(v, torch.Tensor):
                target_v = v
            elif contains(k, ['linear_q_down_proj', 'linear_kv_down_proj', 'linear_q_proj', 'linear_q_up_proj', 'linear_kv_up_proj']) and 'norm' not in k:
                target_v = split_column_parallel(v, tp_rank, tp_size)
            elif 'linear_proj' in k:
                target_v = split_row_parallel(v, tp_rank, tp_size)
            elif 'mlp.linear_fc2' in k: # down proj in Dense Layer
                target_v = split_row_parallel(v, tp_rank, tp_size)
            elif 'mlp.linear_fc1' in k and 'norm' not in k: # gate_up proj in Dense Layer
                # Split Gated Column Linear
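                # fc1 stacks the gate and up projections along dim 0 ([2 * ffn_hidden, hidden]);
                # view it as (2, ffn_hidden, hidden) so each tp_rank slices its ffn chunk from
                # both projections, then flatten back to 2D.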
                seg = args.ffn_hidden_size // args.tensor_model_parallel_size
                viewed = v.view(-1, args.ffn_hidden_size, args.hidden_size)
                target_v = viewed[:, seg * tp_rank: seg * (tp_rank + 1), :].reshape(-1, args.hidden_size)
            elif 'experts' in k and 'shared_experts' not in k:
                # NOTE: skip experts that do not belong to the current ep_rank
                expert_rank = int(re.findall(pattern, k)[0])
                if expert_rank // num_local_experts != ep_rank:
                    continue
                expert_local_rank = expert_rank % num_local_experts
                if args.moe_grouped_gemm:
                    k = k.replace(f'weight{expert_rank}', f'weight{expert_local_rank}')
                else:
                    k = k.replace(f'local_experts.{expert_rank}', f'local_experts.{expert_local_rank}')
                if 'linear_fc1' in k:
                    viewed = v.view(-1, args.moe_ffn_hidden_size, args.hidden_size)
                    seg = args.moe_ffn_hidden_size // etp_size
                    target_v = viewed[:, seg * etp_rank: seg * (etp_rank + 1), :].reshape(-1, args.hidden_size)
                elif 'linear_fc2' in k:
                    target_v = split_row_parallel(v, etp_rank, etp_size)
                else:
                    raise NotImplementedError()
            elif 'shared_experts' in k and 'gate' not in k:
                # Shared experts come from the dense MLP path, so split by tp_rank
                if 'linear_fc1' in k:
                    viewed = v.view(-1, args.moe_shared_expert_intermediate_size, args.hidden_size)
                    seg = args.moe_shared_expert_intermediate_size // tp_size
                    target_v = viewed[:, seg * tp_rank: seg * (tp_rank + 1), :].reshape(-1, args.hidden_size)
                elif 'linear_fc2' in k:
                    target_v = split_row_parallel(v, tp_rank, tp_size)
                else:
                    raise NotImplementedError()
            elif "word_embeddings" in k or "output_layer" in k:
                target_v = split_column_parallel(v, tp_rank, tp_size)
            elif 'eh_proj' in k:
                target_v = split_column_parallel(v, tp_rank, tp_size)
            else:
                target_v = v

            if "embedding.word_embeddings" in k:
                if pp_rank == 0 or (args.mtp_num_layers > 0 and has_mtp):
                    model_split[k] = target_v
            elif "output_layer" in k or "final_layernorm" in k:
                if pp_rank == args.pipeline_model_parallel_size - 1:
                    model_split[k] = target_v
            else:
                model_split[k] = target_v
        save_state_dict(args, [model_split], checkpoint_name)
    print(f'megatron model is saved to {args.save}')
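
For reference, the helpers called above (split_column_parallel, split_row_parallel, contains, check_layer, generate_rank_group) are defined elsewhere in the converter. The sketch below shows plausible minimal versions inferred from the call sites and the usual Megatron tensor-parallel weight layout; it is an illustration under those assumptions, not the actual implementations.

# Hedged sketches of the helpers referenced above. The behavior is inferred from
# the call sites in save_mgmodel and the standard Megatron tensor-parallel layout;
# treat these as illustrations, not the authoritative converter code.
from itertools import product

import torch


def split_column_parallel(tensor, rank, size):
    # Column-parallel weights are sharded along the output dimension (dim 0).
    return torch.chunk(tensor, size, dim=0)[rank].contiguous()


def split_row_parallel(tensor, rank, size):
    # Row-parallel weights are sharded along the input dimension (dim 1).
    return torch.chunk(tensor, size, dim=1)[rank].contiguous()


def contains(key, substrings):
    # True if any of the given substrings occurs in the key.
    return any(s in key for s in substrings)


def check_layer(layers_to_copy, key):
    # True if the key belongs to one of the decoder layers mapped to this pp stage.
    return any(key.startswith(prefix + '.') for prefix in layers_to_copy)


def generate_rank_group(tp_size, etp_size, ep_size, pp_size):
    # A plain Cartesian product over the parallel ranks for illustration; the real
    # helper may constrain which (tp, etp, ep, pp) combinations are enumerated.
    yield from product(range(tp_size), range(etp_size), range(ep_size), range(pp_size))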