in toolkits/model_checkpoints_convertor/qwen/hf2mcore_qwen2_dense_and_moe_gqa.py [0:0]
def load_megatron_model(args):
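    """Reassemble a Megatron-Core checkpoint saved with tensor/pipeline/expert
    parallelism into a single, unsharded model instance."""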
args.tensor_model_parallel_size = args.target_tensor_model_parallel_size
args.pipeline_model_parallel_size = args.target_pipeline_model_parallel_size
if args.num_experts is not None:
args.expert_model_parallel_size = args.target_expert_model_parallel_size
    if args.tensor_model_parallel_size > 1:
        args.sequence_parallel = True
assert args.num_query_groups >= args.target_tensor_model_parallel_size
os.makedirs(args.save, exist_ok=True)
os.system("cp -rf " + args.hf_ckpt_path + "/config*.json " + args.save)
os.system("cp -rf " + args.hf_ckpt_path + "/generation_config.json " + args.save)
os.system("cp -rf " + args.hf_ckpt_path+ "/tokenizer* " + args.save)
os.system("cp -rf " + args.hf_ckpt_path + "/vocab.json " + args.save)
os.system("cp -rf " + args.hf_ckpt_path + "/merges.txt " + args.save)
os.system("cp -rf " + args.hf_ckpt_path + "/config*.json " + args.load)
os.system("cp -rf " + args.hf_ckpt_path + "/generation_config.json " + args.load)
os.system("cp -rf " + args.hf_ckpt_path+ "/tokenizer* " + args.load)
os.system("cp -rf " + args.hf_ckpt_path + "/vocab.json " + args.load)
os.system("cp -rf " + args.hf_ckpt_path + "/merges.txt " + args.load)
model = model_provider()
model_path = args.load
tracker_filename = get_checkpoint_tracker_filename(model_path)
iteration, release = read_metadata(tracker_filename)
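    # Per-rank attention geometry: head size and the number of GQA query groups held by each TP rank.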
head_dim = args.hidden_size // args.num_attention_heads
group_per_split = args.num_query_groups // args.tensor_model_parallel_size
if args.num_experts is not None:
num_local_experts = args.num_experts // args.expert_model_parallel_size
state_dict = {}
mid_state = defaultdict(list)
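    # Case 1: no model parallelism -- the single shard already contains the full model.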
if (
args.tensor_model_parallel_size == 1
and args.pipeline_model_parallel_size == 1
):
checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, None, None, None, None)
state_dict = torch.load(checkpoint_name, weights_only=False)['model']
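    # Case 2: tensor parallelism only (dense model) -- load every TP shard, then concatenate
    # each parameter along its TP split dimension.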
elif (
args.tensor_model_parallel_size > 1
and args.pipeline_model_parallel_size == 1
and args.num_experts is None
):
for tp_rank in range(args.tensor_model_parallel_size):
checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, tp_rank, None, None, None)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)['model']
for k, v in split_state.items():
mid_state[k].append(v)
        for k, v in mid_state.items():
            if not isinstance(v[0], torch.Tensor) or 'norm' in k:
                # Norms (and other non-tensor entries) are replicated across TP ranks; keep one copy.
                target_v = v[0]
            elif 'extra_state' in k:
                target_v = None
            elif 'embedding' in k or 'output_layer' in k:
                # Vocab/column-parallel weights are split along dim 0.
                target_v = torch.cat(v, dim=0)
            elif 'linear_proj' in k or 'linear_fc2' in k:
                # Row-parallel weights are split along dim 1.
                target_v = torch.cat(v, dim=1)
            elif 'linear_qkv.weight' in k:
                # Fused QKV weight: regroup per GQA query group before concatenating the TP shards.
                viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
            elif 'linear_qkv.bias' in k:
                viewed = [x.view(group_per_split, -1) for x in v]
                target_v = torch.cat(viewed, dim=0).view(-1)
            elif 'linear_fc1' in k:
                # Fused gate/up MLP weight: merge the gate halves and the up halves separately.
                viewed = [x.view(2, -1, args.hidden_size) for x in v]
                target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f'unexpected key {k} when merging TP shards')
            state_dict[k] = target_v
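    # Case 3: tensor + pipeline parallelism (dense model) -- remap each stage's local layer
    # indices to global indices, then merge the TP shards as in Case 2.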
elif (
args.tensor_model_parallel_size > 1
and args.pipeline_model_parallel_size > 1
and args.num_experts is None
):
num_layers = args.num_layers // args.pipeline_model_parallel_size
layers_to_copy = {}
for tp_rank in range(args.tensor_model_parallel_size):
for pp_rank in range(args.pipeline_model_parallel_size):
layer_offset = pp_rank * num_layers
for layer in range(num_layers):
pp_layer_id = layer + layer_offset
layers_to_copy[f"decoder.layers.{layer}"] = pp_layer_id
checkpoint_name = get_checkpoint_name(model_path, iteration, release, True, tp_rank, pp_rank, None, None)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)['model']
for k, v in split_state.items():
                    try:
                        # Rewrite the stage-local layer index in the key to its global index.
                        res = re.findall(r'\d+', k)
                        k = re.sub(r"decoder.layers.\d+", "decoder.layers." + str(layers_to_copy["decoder.layers." + res[0]]), k)
                        mid_state[k].append(v)
                    except (IndexError, KeyError):
                        # Keys outside decoder.layers (embeddings, final norm, output layer) keep their names.
                        mid_state[k].append(v)
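        # Merge the collected TP shards per key (same split rules as Case 2).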
for k, v in mid_state.items():
if not isinstance(v[0], torch.Tensor) or 'norm' in k:
target_v = v[0]
elif 'extra_state' in k:
target_v = None
elif 'embedding' in k or 'output_layer' in k:
target_v = torch.cat(v, dim=0)
elif 'linear_proj' in k or 'linear_fc2' in k:
target_v = torch.cat(v, dim=1)
elif 'linear_qkv.weight' in k:
viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k:
viewed = [x.view(group_per_split, -1) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1)
elif 'linear_fc1' in k:
viewed = [x.view(2, -1, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            else:
                raise ValueError(f'unexpected key {k} when merging TP shards')
state_dict[k] = target_v
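    # Case 4: expert parallelism only -- renumber each rank's local experts to their global
    # expert ids; no tensor concatenation is needed.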
elif (
args.tensor_model_parallel_size == 1
and args.pipeline_model_parallel_size == 1
        and args.expert_model_parallel_size > 1
and args.num_experts % args.expert_model_parallel_size == 0
):
for ep_rank in range(args.expert_model_parallel_size):
checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, None, None, True, ep_rank)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)['model']
for k, v in split_state.items():
if 'local_experts' in k:
expert_local_rank = name_to_expert_rank(k)
expert_rank = expert_local_rank + num_local_experts * ep_rank
k = k.replace(f'local_experts.{expert_local_rank}', f'local_experts.{expert_rank}')
state_dict[k] = v
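    # Case 5: tensor + expert parallelism (no pipeline) -- renumber experts per EP rank, keep
    # EP-replicated weights from ep_rank 0 only, then merge the TP shards.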
elif (
args.tensor_model_parallel_size > 1
and args.pipeline_model_parallel_size == 1
and args.expert_model_parallel_size > 1
and args.num_experts % args.expert_model_parallel_size == 0
):
for tp_rank in range(args.tensor_model_parallel_size):
for ep_rank in range(args.expert_model_parallel_size):
                # Expert parallelism is guaranteed here (EP > 1), so every shard is expert-parallel.
                checkpoint_name = get_checkpoint_name(model_path, iteration, release, None, tp_rank, None, True, ep_rank)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)['model']
for k, v in split_state.items():
if 'local_experts' in k and 'norm' not in k:
local_expert_rank = name_to_expert_rank(k)
expert_rank = local_expert_rank + num_local_experts * ep_rank
k = k.replace(f'local_experts.{local_expert_rank}', f'local_experts.{expert_rank}')
mid_state[k].append(v)
elif ep_rank == 0:
mid_state[k].append(v)
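        # Merge the TP shards per key; router/gate weights are replicated across TP ranks and kept once.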
for k, v in mid_state.items():
if not isinstance(v[0], torch.Tensor) or 'router' in k or 'gate' in k:
target_v = v[0]
elif 'extra_state' in k:
target_v = None
elif 'embedding' in k or 'output_layer' in k:
target_v = torch.cat(v, dim=0)
elif 'linear_proj' in k or 'linear_fc2' in k:
target_v = torch.cat(v, dim=1)
elif 'linear_qkv.weight' in k:
viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k:
viewed = [x.view(group_per_split, -1) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1)
elif 'linear_fc1' in k:
viewed = [x.view(2, -1, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            elif 'input_layernorm' in k or 'kv_layernorm' in k or 'pre_mlp_layernorm' in k:
                # Layer norms are replicated across TP ranks; keep a single copy.
                target_v = v[0]
            else:
                raise ValueError(f'unexpected key {k} when merging shards')
state_dict[k] = target_v
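    # Case 6: pipeline parallelism combined with optional tensor and expert parallelism -- remap
    # layer indices per stage, renumber experts per EP rank, and keep replicated weights once.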
elif (
args.tensor_model_parallel_size >= 1
and args.pipeline_model_parallel_size > 1
and args.expert_model_parallel_size >= 1
        and (args.num_experts is None or args.num_experts % args.expert_model_parallel_size == 0)
):
num_layers = args.num_layers // args.pipeline_model_parallel_size
layers_to_copy = {}
for tp_rank in range(args.tensor_model_parallel_size):
for ep_rank in range(args.expert_model_parallel_size):
for pp_rank in range(args.pipeline_model_parallel_size):
layer_offset = pp_rank * num_layers
for layer in range(num_layers):
pp_layer_id = layer + layer_offset
layers_to_copy[f"decoder.layers.{layer}"] = pp_layer_id
                    if args.expert_model_parallel_size > 1:
                        checkpoint_name = get_checkpoint_name(model_path, iteration, release, True, tp_rank, pp_rank, True, ep_rank)
                    else:
                        checkpoint_name = get_checkpoint_name(model_path, iteration, release, True, tp_rank, pp_rank, False)
print(f'load {checkpoint_name}')
split_state = torch.load(checkpoint_name, map_location="cpu", weights_only=False)['model']
for k, v in split_state.items():
                        try:
                            if 'local_experts' in k:
                                # Renumber the expert from its EP-local index to its global index.
                                local_expert_rank = name_to_expert_rank(k)
                                expert_rank = local_expert_rank + num_local_experts * ep_rank
                                k = k.replace(f'local_experts.{local_expert_rank}', f'local_experts.{expert_rank}')
                            # Rewrite the stage-local layer index in the key to its global index.
                            res = re.findall(r'\d+', k)
                            tgt = re.sub(r"decoder.layers.\d+", "decoder.layers." + str(layers_to_copy["decoder.layers." + res[0]]), k)
                            if 'linear_proj' in k or 'linear_q_proj' in k or 'linear_kv_up_proj' in k or 'decoder.layers.0.mlp.linear_fc2' in k or \
                                    'decoder.layers.0.mlp.linear_fc1' in k or 'shared_experts.linear_fc1' in k or 'shared_experts.linear_fc2' in k:
                                # These weights are replicated across EP ranks; keep them from ep_rank 0 only.
                                if ep_rank == 0:
                                    mid_state[tgt].append(v)
                            else:
                                mid_state[tgt].append(v)
                        except (IndexError, KeyError):
                            # Keys outside decoder.layers: embeddings come from the first stage and the
                            # output layer / final norm from the last stage, once per EP group.
                            print(f'handle non-decoder-layer key {k}')
                            if "word_embeddings" in k:
                                if ep_rank == 0 and pp_rank == 0:
                                    mid_state[k].append(v)
                            elif "output_layer" in k or "final_layernorm" in k:
                                if ep_rank == 0 and pp_rank == args.pipeline_model_parallel_size - 1:
                                    mid_state[k].append(v)
                            else:
                                raise ValueError(f"unexpected key {k} outside decoder layers")
for k, v in mid_state.items():
if not isinstance(v[0], torch.Tensor) or 'router' in k or 'gate' in k:
target_v = v[0]
elif 'extra_state' in k:
target_v = None
elif 'word_embeddings' in k or 'output_layer' in k or 'final_layernorm' in k:
target_v = torch.cat(v, dim=0)
elif 'linear_proj' in k or 'linear_fc2' in k:
target_v = torch.cat(v, dim=1)
elif 'linear_qkv.weight' in k:
viewed = [x.view(group_per_split, -1, head_dim, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1, args.hidden_size)
elif 'linear_qkv.bias' in k:
viewed = [x.view(group_per_split, -1) for x in v]
target_v = torch.cat(viewed, dim=0).view(-1)
elif 'linear_fc1' in k and "layer_norm_weight" not in k:
viewed = [x.view(2, -1, args.hidden_size) for x in v]
target_v = torch.cat(viewed, dim=1).view(-1, args.hidden_size)
            elif 'input_layernorm' in k or 'layer_norm_weight' in k or 'pre_mlp_layernorm' in k:
                # Layer norms are replicated across TP ranks; keep a single copy.
                target_v = v[0]
            else:
                raise ValueError(f'unexpected key {k} when merging shards')
state_dict[k] = target_v
else:
raise ValueError('not support yet')
model.load_state_dict(state_dict, strict=False)
return model
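
# Usage sketch (a minimal example, assuming the surrounding converter script has populated
# `args` with target_*_parallel_size, hf_ckpt_path, load and save via its argument parser):
#   mg_model = load_megatron_model(args)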