def load_megatron_model(args)

in toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen2_vl.py


def load_megatron_model(args):
    """load a TPxPPx checkpoint into a TP1PP1 model."""
    os.makedirs(args.save, exist_ok=True)

    model = model_provider()
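    # model_provider() builds the merged TP1PP1 target; the args below are then
    # switched to the source checkpoint's parallel layout (the target_* CLI flags)
    # so its shards can be located and stitched together.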
    args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
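    # With an interleaved schedule, each pipeline stage holds several model chunks;
    # derive how many from layers-per-stage / layers-per-virtual-stage.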
    if args.target_num_layers_per_virtual_pipeline_stage is not None:
        args.num_layers_per_virtual_pipeline_stage = args.target_num_layers_per_virtual_pipeline_stage
        num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
        args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
            args.num_layers_per_virtual_pipeline_stage

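    # Resolve which iteration to load via the checkpoint tracker file.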
    model_path = args.load
    tracker_filename = get_checkpoint_tracker_filename(model_path)
    iteration, release = read_metadata(tracker_filename)

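    # Per-head width and query groups per TP rank, needed to reshape fused QKV.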
    head_dim = args.hidden_size // args.num_attention_heads
    group_per_split = args.num_query_groups // args.tensor_model_parallel_size
    
    vision_state_dicts = defaultdict(dict)
    state_dict = {}
    mid_state = defaultdict(list)
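    # Three layouts: TP1PP1 loads directly; TP>1 merges shards along the TP axis;
    # PP>1 additionally remaps per-stage layer ids to global ones before merging.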
    if (
        args.tensor_model_parallel_size == 1
        and args.pipeline_model_parallel_size == 1
    ):
        checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, None, None, None, None)
        state_dict = torch.load(checkpoint_name, weights_only=False)['model']

    elif (
        args.tensor_model_parallel_size > 1
        and args.pipeline_model_parallel_size == 1
    ):
        for tp_rank in range(args.tensor_model_parallel_size):
            checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, tp_rank, None, None, None)
            print(f'load {checkpoint_name}')
            split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)['model']
            for k, v in split_state.items():
                if k.startswith('vision_model'):
                    vision_state_dicts[(tp_rank, 0)][k] = v
                else:
                    mid_state[k].append(v)
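        # Merge TP shards: column-parallel weights concatenate on dim 0, row-parallel
        # on dim 1; fused QKV and the gated fc1 are reshaped per group/half first.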
        for k, v in mid_state.items():
            if 'extra_state' in k:
                continue
            elif not isinstance(v[0], torch.Tensor) or 'norm' in k:
                target_v = v[0]
            elif 'embedding' in k or 'output_layer' in k:
                target_v = torch.cat(v, dim=0)
            elif 'linear_proj' in k or 'linear_fc2' in k:
                target_v = torch.cat(v, dim=1)
            elif 'linear_qkv.weight' in k:
                viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif 'linear_qkv.bias' in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif 'linear_fc1' in k:
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f'unexpected key when merging TP shards: {k}')
            state_dict[k] = target_v
    elif (
        args.pipeline_model_parallel_size > 1
    ):
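        # ltog maps (pp_rank, vpp_rank, local_layer_id) -> global layer id for the
        # flat TP1PP1 numbering.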
        ltog, _ = build_layer_id_mapping(args)
        for tp_rank in range(args.tensor_model_parallel_size):
            for pp_rank in range(args.pipeline_model_parallel_size):
                checkpoint_name = get_checkpoint_name(model_path, iteration, release, True, tp_rank, pp_rank, None, None)
                print(f'load {checkpoint_name}')
                keys = ['model']
                if args.virtual_pipeline_model_parallel_size is not None:
                    keys = [f'model{i}' for i in range(args.virtual_pipeline_model_parallel_size)]
                split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)
                for vpp_id, key in enumerate(keys):
                    for k, v in split_state[key].items():
                        if k.startswith('vision_model'):
                            assert pp_rank == 0
                            vision_state_dicts[(tp_rank, 0)][k] = v
                            continue
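                        # Rewrite "decoder.layers.<local>" to the global index; keys
                        # without a layer id (embedding, final norm) fall through.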
                        try:
                            pattern = re.compile(r'\d+')
                            local_id = int(pattern.findall(k)[0])
                            global_id = ltog[(pp_rank, vpp_id, local_id)]
                            tgt = re.sub(r"decoder.layers.\d+", f"decoder.layers.{global_id}", k)
                            mid_state[tgt].append(v)
                        except Exception as e:
                            print(f"Skipping {k} with exception {e}")
                            mid_state[k].append(v)

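        # Same TP merge rules as the PP==1 branch above.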
        for k, v in mid_state.items():
            if 'extra_state' in k:
                continue
            elif not isinstance(v[0], torch.Tensor) or 'norm' in k:
                target_v = v[0]
            elif 'embedding' in k or 'output_layer' in k:
                target_v = torch.cat(v, dim=0)
            elif 'linear_proj' in k or 'linear_fc2' in k:
                target_v = torch.cat(v, dim=1)
            elif 'linear_qkv.weight' in k:
                viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif 'linear_qkv.bias' in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif 'linear_fc1' in k:
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f'unexpected key when merging TP shards: {k}')
            state_dict[k] = target_v
    else:
        raise ValueError('this TP/PP configuration is not supported yet')

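    # The vision tower shards are merged separately; language-model weights load
    # non-strictly, and the residual key lists are printed for sanity checking.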
    load_split_state_dict_to_vision_model(vision_state_dicts, model.vision_model, args)
    _missing, _unexpected = model.load_state_dict(state_dict, strict=False)
    missing = list(filter(lambda k: 'extra_state' not in k and not k.startswith('vision_model'), _missing))
    unexpected = list(filter(lambda k: 'extra_state' not in k and not k.startswith('vision_model'), _unexpected))
    print(f"missing keys: {missing}; unexpected keys: {unexpected}")
    return model
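
The 'linear_fc1' branch is the one merge a plain torch.cat cannot reproduce:
each TP shard stacks its gate half on top of its up half, so the halves have to
be regrouped across ranks. A minimal sketch with toy shapes (all sizes
hypothetical) of why the view-then-cat(dim=1) above is needed:

import torch

hidden, ffn_per_rank = 3, 4
gate = [torch.full((ffn_per_rank, hidden), 10.0 + r) for r in range(2)]
up = [torch.full((ffn_per_rank, hidden), 20.0 + r) for r in range(2)]
# Each TP shard is [gate_r; up_r] stacked along dim 0, as in a gated MLP fc1.
shards = [torch.cat([gate[r], up[r]], dim=0) for r in range(2)]

# A naive torch.cat(shards, dim=0) would give [gate0, up0, gate1, up1] -- wrong.
# Viewing each shard as [2, ffn_per_rank, hidden] and concatenating on dim 1
# regroups the halves into [gate0, gate1, up0, up1].
viewed = [x.view(2, -1, hidden) for x in shards]
merged = torch.cat(viewed, dim=1).view(-1, hidden)
assert torch.equal(merged, torch.cat(gate + up, dim=0))

A hypothetical driver sketch (the Namespace fields shown are only those this
function reads directly; in the repo the values come from the converter's CLI,
and model_provider() typically consumes further global args):

from argparse import Namespace

args = Namespace(
    load='/path/to/mcore_tp2pp1_checkpoint',    # sharded source checkpoint
    save='/path/to/output_dir',                 # created if missing
    target_tensor_model_parallel_size=2,
    target_pipeline_model_parallel_size=1,
    target_num_layers_per_virtual_pipeline_stage=None,
    virtual_pipeline_model_parallel_size=None,
    num_layers=28,
    hidden_size=3584,
    num_attention_heads=28,
    num_query_groups=4,
)
model = load_megatron_model(args)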