in chatlearn/utils/vllm_utils.py [0:0]
def __init__(self, src_names, layer_offset, qwen_version=QwenVersion.v_1):
    self.qwen_version = qwen_version
    src_prefix = "module.module.language_model"
    # configuration for different versions of qwen
    if qwen_version == QwenVersion.v_1:
        dst_prefix = "model.transformer"
        embed_name = "wte"
        att_dense_name = ".attn.c_proj."
        self.layer_prefix = "h"
        mlp_dense_name = ".mlp.c_proj."
        final_norm = "ln_f"
    elif qwen_version == QwenVersion.v_2:
        dst_prefix = "model" if is_vllm_v2() else "model.model"
        embed_name = "embed_tokens"
        att_dense_name = ".self_attn.o_proj."
        self.layer_prefix = "layers"
        mlp_dense_name = ".mlp.down_proj."
        final_norm = "norm"
    else:
        raise RuntimeError(f"Unsupported qwen version {qwen_version}, only 1.0 or 2.0 for now.")
    # The regex to extract layer names.
    self.layer_re = re.compile(rf"{src_prefix}.encoder.layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
    self.src_prefix = src_prefix
    self.dst_prefix = dst_prefix
    self._embedding_sync_map = {
        f"{src_prefix}.embedding.word_embeddings.weight": f"{dst_prefix}.{embed_name}.weight",
        "module.module.word_embeddings.weight": f"{dst_prefix}.{embed_name}.weight"
    }
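    # Per-layer name fragments: Megatron-style op names mapped to their vLLM counterparts.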
    self._layer_sync_map = {
        "attention.attention_layernorm": ".attn.attention_layernorm.",
        "attention.dense": ".attn.c_proj.",
        "self_attention.dense": att_dense_name,
        "mlp.dense_h_to_4h": ".mlp.gate_up_proj.",
        "mlp.w1": ".mlp.gate_up_proj.",
        "mlp.w2": ".mlp.gate_up_proj.",
        "mlp.dense_4h_to_h": mlp_dense_name,
        "mlp.dense_layernorm": "mlp.dense_layernorm",
        "mlp.router.layer": ".mlp.gate.",
        "mlp.experts.dense_h_to_4h": ".mlp.experts.w13_weight",
        "mlp.experts.dense_4h_to_h": ".mlp.experts.w2_weight",
        "mlp.shared_experts.dense_h_to_4h": ".mlp.shared_expert.gate_up_proj.",
        "mlp.shared_experts.dense_4h_to_h": ".mlp.shared_expert.down_proj.",
        "mlp.gate": ".mlp.shared_expert_gate."
    }
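    # Example (illustrative): for QwenVersion.v_2, a source name such as
    #   "module.module.language_model.encoder.layers.3.self_attention.dense.weight"
    # matches self.layer_re with groups ("3", "self_attention.dense", "weight"),
    # and "self_attention.dense" maps via _layer_sync_map to ".self_attn.o_proj.".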
    self._final_layer_sync_map = {
        f"{src_prefix}.encoder.final_layernorm.bias": f"{dst_prefix}.{final_norm}.bias",
        f"{src_prefix}.encoder.final_layernorm.weight": f"{dst_prefix}.{final_norm}.weight",
        f"{src_prefix}.output_layer.weight": "lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"
    }
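    # The dicts below mark parameter groups that need special handling during weight
    # sync (concatenation, activation/QKV reordering, routed-expert all-gather or
    # all-to-all); they only describe the groups, the transformations themselves
    # happen elsewhere.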
    self._concat_params_dict = {
        "modules": ["mlp.w1", "mlp.w2"],
        "dim": 0
    }
    self._to_fix_shared_expert_ordering = {
        "modules": ["mlp.shared_experts.dense_h_to_4h"],
        "dim": 0
    }
    self._to_fix_act_ordering_dict = {
        "modules": ["mlp.dense_h_to_4h"],
        "dim": 0
    }
    self._to_fix_qkv_ordering_dict = {
        "modules": [
            "attention.query_key_value",
            "self_attention.query_key_value"
        ],
        "layer_re": self.layer_re
    }
    self._to_allgather_routed_experts_dict = {
        "modules": [
            "mlp.experts.dense_h_to_4h",
            "mlp.experts.dense_4h_to_h",
        ],
        "layer_re": self.layer_re
    }
    self._to_alltoall_routed_experts_dict = {
        "modules": [
            "mlp.experts.dense_h_to_4h",
            "mlp.experts.dense_4h_to_h",
        ],
        "layer_re": self.layer_re
    }
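    # Reorder src_names so that each "mlp.w2" entry (expected to immediately follow
    # its "mlp.w1" counterpart) comes before "mlp.w1"; the "mlp.w2" entry is skipped
    # when it is reached in its original position.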
    src_names_list = []
    for idx, s_name in enumerate(src_names):
        if "mlp.w1" in s_name:
            src_names_list.append(src_names[idx + 1])
            src_names_list.append(s_name)
        elif "mlp.w2" in s_name:
            continue
        else:
            src_names_list.append(s_name)
    super().__init__(src_names_list, layer_offset)
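
A minimal usage sketch, assuming the enclosing class is the Qwen sync map defined in this file (shown here under the hypothetical name QwenSyncMap) and that QwenVersion and the base-class signature match the code above:

# Illustrative only: "QwenSyncMap" is a placeholder for the enclosing class, whose
# real name is not shown in this excerpt.
src_names = [
    "module.module.language_model.embedding.word_embeddings.weight",
    "module.module.language_model.encoder.layers.0.self_attention.query_key_value.weight",
    "module.module.language_model.encoder.layers.0.self_attention.dense.weight",
    "module.module.language_model.encoder.layers.0.mlp.w1.weight",
    "module.module.language_model.encoder.layers.0.mlp.w2.weight",
    "module.module.language_model.encoder.final_layernorm.weight",
    "module.module.language_model.output_layer.weight",
]
sync_map = QwenSyncMap(src_names, layer_offset=0, qwen_version=QwenVersion.v_2)
# After __init__, the (w1, w2) pair has been reordered to (w2, w1) before being passed
# to the base class, and sync_map._embedding_sync_map maps the word-embedding weight to
# "model.embed_tokens.weight" (vLLM v2) or "model.model.embed_tokens.weight".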