in toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen2.5_vl.py [0:0]
def load_megatron_model(args):
"""load a TPxPPx checkpoint into a TP1PP1 model."""
os.makedirs(args.save, exist_ok=True)
model = model_provider()
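    # Adopt the parallel layout of the checkpoint being loaded (the target_* sizes)
    # so the per-rank checkpoint files below are enumerated and merged correctly.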
    args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
    args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
    if args.target_num_layers_per_virtual_pipeline_stage is not None:
        args.num_layers_per_virtual_pipeline_stage = args.target_num_layers_per_virtual_pipeline_stage
        num_layers_per_pipeline_stage = args.num_layers // args.pipeline_model_parallel_size
        args.virtual_pipeline_model_parallel_size = num_layers_per_pipeline_stage // \
            args.num_layers_per_virtual_pipeline_stage
    model_path = args.load
    tracker_filename = get_checkpoint_tracker_filename(model_path)
    iteration, release = read_metadata(tracker_filename)
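    # Per-head dimension and number of query groups per TP shard, used below to
    # reshape the fused GQA linear_qkv tensors before concatenating shards.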
    head_dim = args.hidden_size // args.num_attention_heads
    group_per_split = args.num_query_groups // args.tensor_model_parallel_size
    vision_state_dicts = defaultdict(dict)
    state_dict = {}
    mid_state = defaultdict(list)
    if (
        args.tensor_model_parallel_size == 1
        and args.pipeline_model_parallel_size == 1
    ):
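        # TP1/PP1 source: the whole model sits in a single rank file.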
        checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, None, None, None, None)
        state_dict = torch.load(checkpoint_name, weights_only=False)['model']
    elif (
        args.tensor_model_parallel_size > 1
        and args.pipeline_model_parallel_size == 1
    ):
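        # TP > 1, PP == 1: load every tensor-parallel shard; vision weights are
        # collected per rank for the dedicated vision loader, language weights
        # are grouped per key so they can be merged below.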
        for tp_rank in range(args.tensor_model_parallel_size):
            checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, tp_rank, None, None, None)
            print(f'load {checkpoint_name}')
            split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)['model']
            for k, v in split_state.items():
                if k.startswith('vision_model'):
                    vision_state_dicts[(tp_rank, 0)][k] = v
                else:
                    mid_state[k].append(v)
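        # Merge per-rank shards into full tensors: row-parallel weights
        # (linear_proj, linear_fc2) concatenate along the input dim, plain
        # column-parallel weights (embedding, output_layer) along the output dim,
        # fused QKV and gate/up fc1 are first re-viewed per group so the
        # interleaved partitions line up, and replicated values (norms,
        # non-tensors) are taken from rank 0.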
        for k, v in mid_state.items():
            if 'extra_state' in k:
                continue
            elif not isinstance(v[0], torch.Tensor) or 'norm' in k:
                target_v = v[0]
            elif 'embedding' in k or 'output_layer' in k:
                target_v = torch.cat(v, dim=0)
            elif 'linear_proj' in k or 'linear_fc2' in k:
                target_v = torch.cat(v, dim=1)
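            # Fused GQA QKV: each shard holds group_per_split query groups of
            # (q-heads + k + v) rows; stack the shards group-wise, then flatten back.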
            elif 'linear_qkv.weight' in k:
                viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif 'linear_qkv.bias' in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif 'linear_fc1' in k:
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f'unexpected key: {k}')
            state_dict[k] = target_v
    elif (
        args.pipeline_model_parallel_size > 1
    ):
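        # PP (and optionally VPP) > 1: map each stage's local layer indices to
        # global ones, then walk every TP x PP rank file (VPP sub-models are
        # stored under the keys model0, model1, ...).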
        ltog, _ = build_layer_id_mapping(args)
        for tp_rank in range(args.tensor_model_parallel_size):
            for pp_rank in range(args.pipeline_model_parallel_size):
                checkpoint_name = get_checkpoint_name(model_path, iteration, release, True, tp_rank, pp_rank, None, None)
                print(f'load {checkpoint_name}')
                keys = ['model']
                if args.virtual_pipeline_model_parallel_size is not None:
                    keys = [f'model{i}' for i in range(args.virtual_pipeline_model_parallel_size)]
                split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)
                for vpp_id, key in enumerate(keys):
                    for k, v in split_state[key].items():
                        if k.startswith('vision_model'):
                            assert pp_rank == 0
                            vision_state_dicts[(tp_rank, 0)][k] = v
                            continue
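                        # Rewrite the local layer index in the key to its global
                        # index so shards from different stages merge under one name.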
                        try:
                            pattern = re.compile(r'\d+')
                            local_id = int(pattern.findall(k)[0])
                            global_id = ltog[(pp_rank, vpp_id, local_id)]
                            tgt = re.sub(r"decoder\.layers\.\d+", f"decoder.layers.{global_id}", k)
                            mid_state[tgt].append(v)
                        except Exception as e:
                            print(f"Skipping {k} with exception {e}")
                            mid_state[k].append(v)
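        # Same TP merge as in the PP == 1 branch, now over globally re-indexed keys.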
        for k, v in mid_state.items():
            if 'extra_state' in k:
                continue
            elif not isinstance(v[0], torch.Tensor) or 'norm' in k:
                target_v = v[0]
            elif 'embedding' in k or 'output_layer' in k:
                target_v = torch.cat(v, dim=0)
            elif 'linear_proj' in k or 'linear_fc2' in k:
                target_v = torch.cat(v, dim=1)
            elif 'linear_qkv.weight' in k:
                viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif 'linear_qkv.bias' in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif 'linear_fc1' in k:
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f'unexpected key: {k}')
            state_dict[k] = target_v
    else:
        raise ValueError('not supported yet')
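    # Vision-tower shards follow their own merge rules and are loaded into the
    # vision sub-module separately from the language state dict.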
    load_split_state_dict_to_vision_model(vision_state_dicts, model.vision_model, args)
    _missing, _unexpected = model.load_state_dict(state_dict, strict=False)
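    # extra_state entries were skipped during merging and vision weights were
    # loaded separately, so exclude both from the missing/unexpected report.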
    missing = list(filter(lambda k: 'extra_state' not in k and not k.startswith('vision_model'), _missing))
    unexpected = list(filter(lambda k: 'extra_state' not in k and not k.startswith('vision_model'), _unexpected))
    print(f"missing keys: {missing}; unexpected keys: {unexpected}")
    return model