in toolkits/model_checkpoints_convertor/mistral/hf2mcore_mixtral.py
def convert_checkpoint_from_transformers_to_megatron(args):
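"""Convert a HuggingFace Mixtral checkpoint into Megatron-Core (mcore) format, sharding it across the target tensor, pipeline and expert model parallel ranks."""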
assert args.world_size == args.target_expert_model_parallel_size * args.target_tensor_model_parallel_size * args.target_pipeline_model_parallel_size
os.makedirs(args.save_path, exist_ok=True)
# Save the config and tokenizer files alongside the converted checkpoint
os.system("cp -rf " + args.load_path + "/*.json " + args.save_path)
os.system("cp -rf " + args.load_path + "/tokeniz* " + args.save_path)
# Saving the tracker file
tracker_filepath = os.path.join(args.save_path, "latest_checkpointed_iteration.txt")
with open(tracker_filepath, "w") as f:
f.write("release")
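# Megatron reads this tracker file to locate the `release` checkpoint directory.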
# create the `release` dir under args.save_path
release_dir = os.path.join(args.save_path, "release")
os.makedirs(release_dir, exist_ok=True)
config = AutoConfig.from_pretrained(args.load_path)
# megatron args
megatron_args = {
"orig_vocab_size": config.vocab_size,
"hidden_size": config.hidden_size,
"num_layers": config.num_hidden_layers,
"num_attention_heads": config.num_attention_heads,
"tensor_model_parallel_size": args.target_tensor_model_parallel_size,
"pipeline_model_parallel_size": args.target_pipeline_model_parallel_size
}
margs = types.SimpleNamespace()
for k, v in megatron_args.items():
setattr(margs, k, v)
state_dict = AutoModelForCausalLM.from_pretrained(args.load_path).state_dict()
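# Re-key the HF state dict into an intermediate naming scheme (fused key_value, MoE expert weights) that the per-rank conversion below consumes.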
internal_state_dict = {}
for layer_id in range(config.num_hidden_layers):
q_weight = state_dict['model.layers.' + str(layer_id) + '.self_attn.q_proj.weight']
k_weight = state_dict['model.layers.' + str(layer_id) + '.self_attn.k_proj.weight']
v_weight = state_dict['model.layers.' + str(layer_id) + '.self_attn.v_proj.weight']
internal_state_dict['transformer.layers.' + str(layer_id) + '.self_attn.query.weight'] = q_weight
internal_state_dict['transformer.layers.' + str(layer_id) + '.self_attn.key_value.weight'] = torch.cat(
(k_weight, v_weight))
internal_state_dict['transformer.layers.' + str(layer_id) + '.self_attn.dense.weight'] = \
state_dict['model.layers.' + str(layer_id) + '.self_attn.o_proj.weight']
internal_state_dict['transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.gate.wg.weight'] = state_dict[
'model.layers.' + str(layer_id) + '.block_sparse_moe.gate.weight']
for expert_id in range(config.num_local_experts):
internal_state_dict[
'transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.experts.megatron_experts.' + str(
expert_id) + '.dense_h_to_4h_1.weight'] = \
state_dict[
'model.layers.' + str(layer_id) + '.block_sparse_moe.experts.' + str(expert_id) + '.w1.weight']
internal_state_dict[
'transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.experts.megatron_experts.' + str(
expert_id) + '.dense_h_to_4h_2.weight'] = \
state_dict[
'model.layers.' + str(layer_id) + '.block_sparse_moe.experts.' + str(expert_id) + '.w3.weight']
internal_state_dict[
'transformer.layers.' + str(layer_id) + '.mlp.megatron_moe.experts.megatron_experts.' + str(
expert_id) + '.dense_4h_to_h.weight'] = state_dict[
'model.layers.' + str(layer_id) + '.block_sparse_moe.experts.' + str(expert_id) + '.w2.weight']
internal_state_dict['transformer.layers.' + str(layer_id) + '.input_layernorm.weight'] = state_dict[
'model.layers.' + str(layer_id) + '.input_layernorm.weight']
internal_state_dict['transformer.layers.' + str(layer_id) + '.post_attention_layernorm.weight'] = state_dict[
'model.layers.' + str(layer_id) + '.post_attention_layernorm.weight']
internal_state_dict["transformer.word_embeddings.weight"] = state_dict['model.embed_tokens.weight']
internal_state_dict["transformer.final_layernorm.weight"] = state_dict['model.norm.weight']
internal_state_dict["transformer.lm_head.weight"] = state_dict['lm_head.weight']
output_state_dict = []
for i in range(args.target_tensor_model_parallel_size):
output_state_dict.append(OrderedDict())
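# Grouped-query attention: the number of query groups equals the number of KV heads.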
num_query_group = config.num_key_value_heads
output_group_state_dict = []
for i in range(num_query_group):
output_group_state_dict.append({})
if args.target_params_dtype == "fp16":
dtype = torch.float16
elif args.target_params_dtype == "bf16":
dtype = torch.bfloat16
else:
dtype = torch.float32
# Embedding layer
print("converting embedding layer")
word_embedding = internal_state_dict["transformer.word_embeddings.weight"].to(dtype)
out_word_embed = torch.chunk(word_embedding, args.target_tensor_model_parallel_size, dim=0)
for i in range(args.target_tensor_model_parallel_size):
word_emb_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
word_emb_dict["embedding.word_embeddings.weight"] = out_word_embed[i]
print("converting output layer")
lm_head = internal_state_dict["transformer.lm_head.weight"].to(dtype)
out_lm_head = torch.chunk(lm_head, args.target_tensor_model_parallel_size, dim=0)
for i in range(args.target_tensor_model_parallel_size):
lm_head_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
lm_head_dict["output_layer.weight"] = out_lm_head[i]
print("converting transformer layers")
if config.num_hidden_layers % args.target_pipeline_model_parallel_size != 0:
raise ValueError(
f"Number of layers ({config.num_hidden_layers}) must be divisible by number of pipeline parallelism"
f" ({args.target_pipeline_model_parallel_size})"
)
num_layers = config.num_hidden_layers // args.target_pipeline_model_parallel_size
layer_re = re.compile(r"transformer\.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
hidden_size = config.hidden_size
num_groups = config.num_key_value_heads
num_heads = config.num_attention_heads
hidden_size_per_head = config.hidden_size // config.num_attention_heads
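# Convert layer by layer, one pipeline stage at a time; each stage owns `num_layers` consecutive transformer layers.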
for pp_rank in range(args.target_pipeline_model_parallel_size):
layer_offset = pp_rank * num_layers
if pp_rank > 0:
output_state_dict = []
for i in range(args.target_tensor_model_parallel_size):
output_state_dict.append({})
output_group_state_dict = []
for i in range(num_query_group):
output_group_state_dict.append({})
for layer in range(num_layers):
pp_layer_id = layer + layer_offset
layers_to_copy = [
layer_name
for layer_name in internal_state_dict.keys()
if layer_name.startswith(f"transformer.layers.{pp_layer_id}.")
]
for layer_name in layers_to_copy:
m = layer_re.match(layer_name)
# Stop if that's not a layer
if m is None:
break
# The index of the layer.
_ = int(m.group(1))
# The name of the operation.
op_name = m.group(2)
# Is it a weight or a bias?
weight_or_bias = m.group(3)
params = internal_state_dict[layer_name].to(dtype)
# handle layernorm
extra_state_name = None
if op_name.startswith("input_layernorm") and weight_or_bias == "weight":
out_name = "self_attention.linear_qkv"
layer_name = f"layers.{layer}.{out_name}.layer_norm_weight"
extra_state_name = f"layers.{layer}.{out_name}._extra_state"
elif op_name.startswith("post_attention_layernorm") and weight_or_bias == "weight":
out_name = "pre_mlp_layernorm"
layer_name = f"layers.{layer}.{out_name}.weight"
extra_state_name = f"layers.{layer}.{out_name}._extra_state"
# handle attention K, V, Q weights
elif op_name.startswith("self_attn.query") and weight_or_bias == "weight":
# Re-order the query projection rows into Megatron's interleaved per-head layout.
params = transformers_to_megatron_fix_query_key_value_ordering(
params,
3.0,
1,
num_heads,
hidden_size_per_head,
)
layer_name = f"layers.{layer}.{op_name}.{weight_or_bias}"
elif op_name.startswith("self_attn.key_value") and weight_or_bias == "weight":
# Interleave the stacked K and V rows so each KV group stores [k, v] contiguously, as Megatron expects.
params = transformers_to_megatron_fix_query_key_value_ordering(
params,
3.0,
2,
num_groups,
hidden_size_per_head,
)
layer_name = f"layers.{layer}.{op_name}.{weight_or_bias}"
# handle attention and mlp weights
elif weight_or_bias == "weight":
out_name = internal_to_output_mapping.get(op_name, None)
if out_name is None:
continue
layer_name = f"layers.{layer}.{out_name}.{weight_or_bias}"
if out_name not in ['linear_fc1_1', 'mlp.router', 'linear_fc1_2']:
extra_state_name = f"layers.{layer}.{out_name}._extra_state"
# skip
else:
continue
if op_name + "." + weight_or_bias in tensor_parallel_params:
dim = 1 if op_name + "." + weight_or_bias in column_split_tensor_parallel_params else 0
params = torch.chunk(params, args.target_tensor_model_parallel_size, dim=dim)
for i in range(args.target_tensor_model_parallel_size):
params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
params_dict["decoder." + layer_name] = (
params[i].clone() if (
op_name + "." + weight_or_bias in tensor_parallel_params) else params.clone()
)
if extra_state_name is not None:
if 'linear_fc1_' not in extra_state_name:
params_dict["decoder." + extra_state_name] = None
for i in range(args.target_tensor_model_parallel_size):
params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
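# Fuse each expert's gate (w1) and up (w3) projections into the single linear_fc1 weight used by Megatron-Core's gated MLP.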
for expert_id in range(config.num_local_experts):
dense_h_to_4h_1_name = f'decoder.layers.{layer}.mlp.experts.local_experts.{expert_id}.linear_fc1_1.weight'
dense_h_to_4h_1_weight = params_dict[dense_h_to_4h_1_name]
del params_dict[dense_h_to_4h_1_name]
dense_h_to_4h_2_name = f'decoder.layers.{layer}.mlp.experts.local_experts.{expert_id}.linear_fc1_2.weight'
dense_h_to_4h_2_weight = params_dict[dense_h_to_4h_2_name]
del params_dict[dense_h_to_4h_2_name]
dense_h_to_4h_name = f'decoder.layers.{layer}.mlp.experts.local_experts.{expert_id}.linear_fc1.weight'
params_dict[dense_h_to_4h_name] = \
torch.cat([dense_h_to_4h_1_weight, dense_h_to_4h_2_weight], dim=0)
self_attn_query_name = f"decoder.layers.{layer}.self_attn.query.weight"
query_weight = params_dict[self_attn_query_name]
del params_dict[self_attn_query_name]
self_attn_kv_name = f"decoder.layers.{layer}.self_attn.key_value.weight"
kv_weight = params_dict[self_attn_kv_name]
del params_dict[self_attn_kv_name]
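# Repack Q and KV so each query group stores its query heads followed by [k, v], matching the fused linear_qkv layout.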
# e.g. Mixtral-8x7B at TP=1: torch.Size([8, 512, 4096])
group_query_weight = query_weight.view(num_groups // args.target_tensor_model_parallel_size,
num_heads // num_groups * hidden_size_per_head, hidden_size)
# e.g. at TP=1: torch.Size([8, 256, 4096])
group_kv_weight = kv_weight.view(num_groups // args.target_tensor_model_parallel_size,
2 * hidden_size_per_head, hidden_size)
group_qkv_weight = torch.cat([group_query_weight, group_kv_weight], dim=1)
params_dict["decoder." + f"layers.{layer}.self_attention.linear_qkv.weight"] = \
group_qkv_weight.view(-1, hidden_size)
params_dict["decoder." + f"layers.{layer}.self_attention.linear_qkv._extra_state"] = None
if pp_rank == args.target_pipeline_model_parallel_size - 1:
# handle final layernorm
for weight_or_bias in ["weight"]:
params = internal_state_dict[f"transformer.final_layernorm.{weight_or_bias}"].to(dtype)
layer_name = "decoder." + f"final_layernorm.{weight_or_bias}"
for i in range(args.target_tensor_model_parallel_size):
params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
params_dict[layer_name] = params.clone()
params_dict["decoder.final_layernorm._extra_state"] = None
# add the embedding
for i in range(args.target_tensor_model_parallel_size):
params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
params_dict["embedding.word_embeddings.weight"] = out_word_embed[i].clone()
# add the LM head
for i in range(args.target_tensor_model_parallel_size):
params_dict = get_element_from_dict_by_path(output_state_dict[i], "model")
params_dict["output_layer.weight"] = out_lm_head[i].clone()
num_ep_groups = args.world_size // args.target_tensor_model_parallel_size // args.target_pipeline_model_parallel_size
experts_ids = list(range(config.num_local_experts))
chunks = [experts_ids[x:x + config.num_local_experts // num_ep_groups] for x in
range(0, len(experts_ids), config.num_local_experts // num_ep_groups)]
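# Map every global expert id to its expert-parallel group and its local index within that group.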
expert_group_mapping = {}
for idx, chunk in enumerate(chunks):
for ele in chunk:
expert_group_mapping[ele] = idx
expert_local_mapping = {}
for chunk in chunks:
for idx, ele in enumerate(chunk):
expert_local_mapping[ele] = idx
# Save the state dict per tp_rank and pp_rank, further split across expert-parallel ranks
for tp_rank in range(args.target_tensor_model_parallel_size):
current_keys = list(output_state_dict[tp_rank]['model'].keys())
ep_state_dict = []
for i in range(args.target_expert_model_parallel_size):
ep_state_dict.append({})
for key in current_keys:
if "local_experts" in key:
keywords = key.split(".")
eid = int(keywords[6])
expert_group_id = expert_group_mapping[eid]
local_expert_id = expert_local_mapping[eid]
keywords[6] = str(local_expert_id)
ep_state_dict[expert_group_id][".".join(keywords)] = (
output_state_dict[tp_rank]['model'][key].clone()
if hasattr(output_state_dict[tp_rank]['model'][key], 'clone')
else output_state_dict[tp_rank]['model'][key]
)
output_state_dict[tp_rank]['model'].pop(key)
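# For each expert-parallel rank, merge its expert slice back into the shared weights and write the checkpoint.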
for ep_rank in range(args.target_expert_model_parallel_size):
checkpoint_dir = get_checkpoint_sub_dir_name(tp_rank, pp_rank, args.target_pipeline_model_parallel_size,
ep_rank, args.target_expert_model_parallel_size)
save_dir = os.path.join(release_dir, checkpoint_dir)
os.makedirs(save_dir, exist_ok=True)
checkpoint_name = "model_optim_rng.pt"
checkpoint_path = os.path.join(save_dir, checkpoint_name)
output_state_dict[tp_rank]['model'].update(ep_state_dict[ep_rank])
save_state_dict(args, [output_state_dict[tp_rank]['model']], checkpoint_path, save_args=False)