def load_megatron_model()

in toolkits/model_checkpoints_convertor/llava/hf2mcore_llava.py


import os
import re
from collections import defaultdict

import torch

# NOTE: assumed import path -- these checkpoint helpers ship with Megatron-LM,
# but their module has moved across versions; model_provider is defined
# elsewhere in this file.
from megatron.training.checkpointing import (
    get_checkpoint_name,
    get_checkpoint_tracker_filename,
    read_metadata,
)


def load_megatron_model(args):
    """
    Load a full (TP=1, PP=1) model from a checkpoint saved with arbitrary
    tensor/pipeline parallel sharding.
    """
    # Stage the output dir and mirror the HuggingFace config/tokenizer files
    # into both the conversion output dir (args.save) and the Megatron
    # checkpoint dir (args.load).
    os.makedirs(args.save, exist_ok=True)
    os.system("cp -rf " + args.hf_ckpt_path + "/config*.json " + args.save)
    os.system("cp -rf " + args.hf_ckpt_path + "/tokenizer* " + args.save)
    os.system("cp -rf " + args.hf_ckpt_path + "/vocab.json " + args.save)
    os.system("cp -rf " + args.hf_ckpt_path + "/merges.txt " + args.save)

    os.system("cp -rf " + args.hf_ckpt_path + "/config*.json " + args.load)
    os.system("cp -rf " + args.hf_ckpt_path + "/tokenizer* " + args.load)
    os.system("cp -rf " + args.hf_ckpt_path + "/vocab.json " + args.load)
    os.system("cp -rf " + args.hf_ckpt_path + "/merges.txt " + args.load)
    
    # Build the full unsharded model on CPU, then point the parallel sizes at
    # the sharded checkpoint we are about to read and merge.
    model = model_provider().cpu()
    args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size

    # The tracker file records the latest saved iteration (or a "release" flag).
    model_path = args.load
    tracker_filename = get_checkpoint_tracker_filename(model_path)
    iteration, release = read_metadata(tracker_filename)

    if args.tensor_model_parallel_size > 1:
        args.sequence_parallel = True

    # Each TP rank must hold at least one query group (GQA KV-head group).
    assert args.num_query_groups >= args.target_tensor_model_parallel_size, (
        "num_query_groups must be >= target tensor parallel size"
    )

    head_dim = args.hidden_size // args.num_attention_heads
    # Number of query groups stored on each TP shard.
    group_per_split = args.num_query_groups // args.target_tensor_model_parallel_size

    state_dict = {}
    mid_state = defaultdict(list)

    if args.tensor_model_parallel_size == 1 and args.pipeline_model_parallel_size == 1:
        # Single shard: the checkpoint already contains the full model.
        checkpoint_name = get_checkpoint_name(
            model_path, iteration, release, None, None, None, None, None
        )
        state_dict = torch.load(checkpoint_name, map_location="cpu")["model"]
    elif args.tensor_model_parallel_size > 1 and args.pipeline_model_parallel_size == 1:
        for tp_rank in range(args.tensor_model_parallel_size):
            checkpoint_name = get_checkpoint_name(
                model_path, iteration, release, None, tp_rank, None, None, None
            )
            print(f"load {checkpoint_name}")
            split_state = torch.load(checkpoint_name, map_location="cpu")["model"]
            for k, v in split_state.items():
                mid_state[k].append(v)
        for k, v in mid_state.items():
            if not isinstance(v[0], torch.Tensor) or "norm" in k:
                # Replicated parameters (norms, scalars) are identical on every rank.
                target_v = v[0]
            elif "embedding" in k or "output_layer" in k:
                # Vocab-parallel weights: shards split the vocab dim (dim 0).
                target_v = torch.cat(v, dim=0)
            elif "linear_proj" in k or "linear_fc2" in k:
                # Row-parallel linears: shards split the input dim (dim 1).
                target_v = torch.cat(v, dim=1)
            elif "linear_qkv.weight" in k:
                # Fused QKV: regroup each shard into its [q..., k, v] blocks per
                # query group, stack the groups across shards, then flatten.
                viewed = [
                    x.view(group_per_split, -1, head_dim, args.hidden_size)
                    for x in v
                ]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif "linear_qkv.bias" in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif "linear_fc1" in k:
                # SwiGLU fc1 stores [gate; up] halves per shard; merge each half
                # across shards separately (see the worked example below).
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f"Unexpected parameter key: {k}")

            state_dict[k] = target_v
    elif args.tensor_model_parallel_size > 1 and args.pipeline_model_parallel_size > 1:
        # Each pipeline stage numbers its layers locally from 0; compute the
        # global index for every local layer before merging TP shards.
        num_layers = args.num_layers // args.pipeline_model_parallel_size
        layers_to_copy = {}
        for tp_rank in range(args.tensor_model_parallel_size):
            for pp_rank in range(args.pipeline_model_parallel_size):
                layer_offset = pp_rank * num_layers
                for layer in range(num_layers):
                    pp_layer_id = layer + layer_offset
                    layers_to_copy[f"decoder.layers.{layer}"] = pp_layer_id
                checkpoint_name = get_checkpoint_name(
                    model_path, iteration, release, True, tp_rank, pp_rank, None, None
                )
                print(f"load {checkpoint_name}")
                split_state = torch.load(checkpoint_name, map_location="cpu")["model"]
                for k, v in split_state.items():
                    try:
                        # Renumber "decoder.layers.<local>" to the global index
                        # (see the renumbering example below).
                        local_id = re.findall(r"\d+", k)[0]
                        k = re.sub(
                            r"decoder.layers.\d+",
                            "decoder.layers."
                            + str(layers_to_copy["decoder.layers." + local_id]),
                            k,
                        )
                        mid_state[k].append(v)
                    except (IndexError, KeyError):
                        # Keys without a layer index (embeddings, final norm, ...).
                        mid_state[k].append(v)
        for k, v in mid_state.items():
            # Merge TP shards exactly as in the TP-only branch above.
            if not isinstance(v[0], torch.Tensor) or "norm" in k:
                target_v = v[0]
            elif "embedding" in k or "output_layer" in k:
                target_v = torch.cat(v, dim=0)
            elif "linear_proj" in k or "linear_fc2" in k:
                target_v = torch.cat(v, dim=1)
            elif "linear_qkv.weight" in k:
                viewed = [
                    x.view(group_per_split, -1, head_dim, args.hidden_size)
                    for x in v
                ]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif "linear_qkv.bias" in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif "linear_fc1" in k:
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f"Unexpected parameter key: {k}")
            state_dict[k] = target_v
    else:
        # TP=1 with PP>1 is not handled; fail loudly instead of silently
        # loading an empty state dict.
        raise ValueError(
            f"Unsupported parallel layout: TP={args.tensor_model_parallel_size}, "
            f"PP={args.pipeline_model_parallel_size}"
        )

    incompat_keys = model.load_state_dict(state_dict, strict=False)

    # "extra_state" entries are typically TransformerEngine bookkeeping buffers;
    # mismatches on them are expected and safe to ignore.
    unexpected_keys = []
    for key in incompat_keys.unexpected_keys:
        if "extra_state" not in key:
            unexpected_keys.append(key)
    assert len(unexpected_keys) == 0, "Unexpected Keys: " + str(unexpected_keys)
    missed_keys = []
    for key in incompat_keys.missing_keys:
        if "extra_state" not in key:
            missed_keys.append(key)
    assert len(missed_keys) == 0, "Missing Keys: " + str(missed_keys)
    return model
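
The linear_fc1 merge above is the easiest step to get wrong, so here is a toy sketch (made-up sizes, not from the converter) of why the view(2, -1, hidden) dance is needed: each shard stacks its gate and up slices, and a naive torch.cat would interleave them instead of regrouping each half.

import torch

# Toy SwiGLU fc1 merge (hypothetical sizes): hidden=2, TP=2, 2 rows per half.
hidden = 2
gate0, up0 = torch.zeros(2, hidden), torch.ones(2, hidden)          # rank 0
gate1, up1 = 2 * torch.ones(2, hidden), 3 * torch.ones(2, hidden)   # rank 1
shards = [torch.cat([gate0, up0]), torch.cat([gate1, up1])]          # [gate; up] each

# Naive concat interleaves the halves: [g0, u0, g1, u1] -- wrong.
naive = torch.cat(shards, dim=0)

# The view trick splits each shard into (gate, up), concatenates each half
# across ranks, then restacks: [g0, g1, u0, u1] -- the full-model layout.
viewed = [x.view(2, -1, hidden) for x in shards]
merged = torch.cat(viewed, dim=1).view(-1, hidden)

print(naive[:, 0].tolist())   # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0]
print(merged[:, 0].tolist())  # [0.0, 0.0, 2.0, 2.0, 1.0, 1.0, 3.0, 3.0]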
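
The layer renumbering in the TP+PP branch can also be checked in isolation. A minimal example with PP=2 and 2 layers per stage: local layer 1 on pipeline stage 1 (layer_offset=2) maps to global layer 3.

import re

# layers_to_copy as the function builds it for pp_rank=1 with num_layers=2.
layers_to_copy = {"decoder.layers.0": 2, "decoder.layers.1": 3}

k = "decoder.layers.1.self_attention.linear_qkv.weight"
local_id = re.findall(r"\d+", k)[0]  # "1"
k = re.sub(
    r"decoder.layers.\d+",
    "decoder.layers." + str(layers_to_copy["decoder.layers." + local_id]),
    k,
)
print(k)  # decoder.layers.3.self_attention.linear_qkv.weight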
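
Finally, a hypothetical call site. The field names below mirror exactly what load_megatron_model reads; a real run gets them from the converter's argument parser, and model_provider additionally consumes Megatron's global args, so a bare namespace like this is only illustrative.

from types import SimpleNamespace

# Illustrative only: all values and paths are placeholders.
args = SimpleNamespace(
    load="/path/to/megatron_ckpt",          # sharded Megatron checkpoint dir
    save="/path/to/output",                 # conversion output dir
    hf_ckpt_path="/path/to/hf_ckpt",        # HF dir with config/tokenizer files
    target_tensor_model_parallel_size=2,
    target_pipeline_model_parallel_size=1,
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    num_query_groups=8,
)
model = load_megatron_model(args)  # full TP=1/PP=1 model on CPU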