# chatlearn/synchronizer/megatron_vllm.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""megatron to vllm synchronizer"""
from abc import abstractmethod
import operator
from functools import reduce
import ray.util.collective as col
import torch
from chatlearn.utils.constant import QwenVersion
from chatlearn.utils.utils import get_use_legacy_models
from chatlearn.utils.vllm_utils import fix_qwen_query_key_value_ordering
from chatlearn.utils.vllm_utils import split_attn_state
from chatlearn.utils.vllm_utils import Megatron2LlamaSyncMap, Megatron2QWenSyncMap, MCore2LlamaSyncMap
from chatlearn.utils.megatron_import_memory_helper import MegatronVersion, get_megatron_version
from .base import BaseSync


class MegatronVllmSync(BaseSync):
"""Megatron to vllm sync"""
def __init__(self, src_model, dst_model):
super().__init__(src_model, dst_model)
self.src_module_args = src_model.module_args
self.dst_module_args = dst_model.module_args
self.is_parameter_changed = True
@abstractmethod
def map_src_to_dst(self, src_names, src_pipe_layer_offset):
"""
:meta private:
"""
def _validate(self, sync_map):
if sync_map.concat_params_dict is not None:
if isinstance(sync_map.concat_params_dict, dict):
assert "modules" in sync_map.concat_params_dict
assert "dim" in sync_map.concat_params_dict
assert isinstance(sync_map.concat_params_dict["modules"], list)
else:
raise RuntimeError(f"Expect concat_params_dict in {self} to be a dict or None, while {sync_map.concat_params_dict}.")
if sync_map.to_fix_act_ordering_dict is not None:
if isinstance(sync_map.to_fix_act_ordering_dict, dict):
assert "modules" in sync_map.to_fix_act_ordering_dict
assert "dim" in sync_map.to_fix_act_ordering_dict
assert isinstance(sync_map.to_fix_act_ordering_dict["modules"], list)
else:
raise RuntimeError(f"Expect to_fix_act_ordering_dict in {self} to be a dict or None, while {sync_map.to_fix_act_ordering_dict}.")
if sync_map.to_fix_qkv_ordering_dict is not None:
if isinstance(sync_map.to_fix_qkv_ordering_dict, dict):
assert "modules" in sync_map.to_fix_qkv_ordering_dict
assert "layer_re" in sync_map.to_fix_qkv_ordering_dict
assert isinstance(sync_map.to_fix_qkv_ordering_dict["modules"], list)
else:
raise RuntimeError(f"Expect to_fix_qkv_ordering_dict in {self} to be a dict or None, while {sync_map.to_fix_qkv_ordering_dict}.")
def map_name_from_src_to_dst(self, send_actor, recv_actor, src_names, dst_names):
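        """
        Build the src-to-dst name mapping via ``map_src_to_dst`` using the
        combined pipeline-stage layer offsets of the send and receive actors,
        validate it, and return the mapped src/dst parameter name lists.
        """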
src_pipe_layer_offset = self.get_or_cache(send_actor, "get_pipeline_stage_layer_offset")
dst_pipe_layer_offset = self.get_or_cache(recv_actor, "get_pipeline_stage_layer_offset")
self.sync_map = self.map_src_to_dst(src_names, src_pipe_layer_offset+dst_pipe_layer_offset)
self._validate(self.sync_map)
self.concat_params_dict = self.sync_map.concat_params_dict
return self.sync_map.src_names, self.sync_map.dst_names
def concat_params(self, params_to_sync_list):
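        """
        Concatenate parameters whose names match ``concat_params_dict["modules"]``
        along ``concat_params_dict["dim"]`` (one concatenated tensor per complete
        group); all other parameters pass through unchanged.
        """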
if self.sync_map.concat_params_dict is None:
return params_to_sync_list
concat_modules_list = self.sync_map.concat_params_dict["modules"]
concat_dim = self.sync_map.concat_params_dict["dim"]
params_to_sync_list_new = []
concat = []
for name, params in params_to_sync_list:
if any(ele in name for ele in concat_modules_list):
concat.append(params)
if len(concat) == len(concat_modules_list):
params = torch.cat(concat, dim=concat_dim)
params_to_sync_list_new.append((name, params))
concat = []
else:
params_to_sync_list_new.append((name, params))
return params_to_sync_list_new
def fix_qkv_ordering(self, params_to_sync_list):
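        """
        Reorder fused QKV weights/biases matched by ``to_fix_qkv_ordering_dict``
        (via ``split_attn_state`` or ``fix_qwen_query_key_value_ordering``) so
        that the head layout matches what the vLLM-side model expects.
        """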
to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
if to_fix_qkv_ordering_dict is None:
return params_to_sync_list
layer_re = self.sync_map.to_fix_qkv_ordering_dict["layer_re"]
to_fix_modules_list = to_fix_qkv_ordering_dict["modules"]
for i, (name, params_to_sync) in enumerate(params_to_sync_list):
m = layer_re.match(name)
if m is None:
continue
op_name = m.group(2)
if op_name in to_fix_modules_list:
checkpoint_version = 3.0
tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
heads = self.src_module_args.args_dict["num_attention_heads"] // tp_size
hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"]
if self._to_fix_qkv_ordering_func is split_attn_state:
_num_query_groups = self.src_module_args.args_dict["num_query_groups"]//tp_size \
if self.src_module_args.args_dict["group_query_attention"] else heads
params_to_sync = self._to_fix_qkv_ordering_func(
params_to_sync, heads, _num_query_groups, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"])
params_to_sync_list[i] = (name, params_to_sync)
else:
input_shape = params_to_sync.size()
shape = (heads, hidden_size_per_head, 3) + input_shape[1:]
division = reduce(operator.mul, shape, 1)
num_elements = params_to_sync.numel()
if num_elements == division:
                        # Models with GQA don't need the QKV ordering fix.
weight_or_bias = m.group(3)
params_to_sync = self._to_fix_qkv_ordering_func(
params_to_sync, checkpoint_version, 3, heads, hidden_size_per_head
)
if weight_or_bias == "weight":
params_to_sync = params_to_sync.contiguous()
params_to_sync_list[i] = (name, params_to_sync)
return params_to_sync_list
def fix_act_ordering(self, params_to_sync_list):
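        """
        Swap the two halves (split along dim 0) of parameters matched by
        ``to_fix_act_ordering_dict["modules"]`` and re-concatenate them along
        ``to_fix_act_ordering_dict["dim"]``, e.g. to reorder gated-MLP weights.
        """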
if self.sync_map.to_fix_act_ordering_dict is None:
return params_to_sync_list
fix_dim = self.sync_map.to_fix_act_ordering_dict["dim"]
to_fix_act_ordering_list = self.sync_map.to_fix_act_ordering_dict["modules"]
for i, (name, params_to_sync) in enumerate(params_to_sync_list):
            if any(ele in name for ele in to_fix_act_ordering_list):
val = params_to_sync
offset = val.shape[0] // 2
w1 = val[:offset,:]
w2 = val[offset:,:]
params_to_sync = torch.cat([w2, w1], dim=fix_dim)
params_to_sync_list[i] = (name, params_to_sync)
return params_to_sync_list
def fix_shared_expert_ordering(self, params_to_sync_list):
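        """
        Swap the two halves (chunked along dim 0) of shared-expert parameters
        matched by ``to_fix_shared_expert_ordering["modules"]`` and
        re-concatenate them along ``to_fix_shared_expert_ordering["dim"]``.
        """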
if self.sync_map.to_fix_shared_expert_ordering is None:
return params_to_sync_list
fix_dim = self.sync_map.to_fix_shared_expert_ordering["dim"]
to_fix_shared_expert_ordering_list = self.sync_map.to_fix_shared_expert_ordering["modules"]
for i, (name, params_to_sync) in enumerate(params_to_sync_list):
            if any(ele in name for ele in to_fix_shared_expert_ordering_list):
w1, w2 = params_to_sync.chunk(2, dim=0)
params_to_sync = torch.cat([w2, w1], dim=fix_dim).contiguous()
params_to_sync_list[i] = (name, params_to_sync)
return params_to_sync_list
def allgather_routed_experts_from_hep(self, name, params_to_sync, group_name, tp_rank):
"""
This function is applicable for synchronizing parameters from QWen with HEP enabled
to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
Thus, the function will all-gather across EP size * TP size routed experts and slice
them to TP size partitions.
"""
if self.sync_map._to_allgather_routed_experts_dict is None:
return params_to_sync, False
to_allgather_routed_experts_dict = self.sync_map._to_allgather_routed_experts_dict
layer_re = to_allgather_routed_experts_dict["layer_re"]
to_regroup_modules_list = to_allgather_routed_experts_dict["modules"]
m = layer_re.match(name)
if m is None:
return params_to_sync, False
op_name = m.group(2)
if op_name in to_regroup_modules_list:
tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
hep_size = tp_size * ep_size
moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]
local_num_experts = moe_num_experts // hep_size
hidden_size = self.src_module_args.args_dict["hidden_size"]
output_tensor_list = [
torch.empty(size=params_to_sync.shape, dtype=params_to_sync.dtype, device=params_to_sync.device)
for _ in range(hep_size)
]
col.allgather(output_tensor_list, params_to_sync, group_name)
del params_to_sync
val_list = []
if "dense_h_to_4h" in op_name:
# w13_weight
while output_tensor_list:
params = output_tensor_list.pop(0)
                    # regroup among different tp slices
params = params.view((moe_num_experts, -1, hidden_size)).contiguous()
params = params.reshape((local_num_experts * 2, -1, hidden_size))
params = params.chunk(tp_size, dim=1)[tp_rank]
# reorder w1 and w3
params = params.reshape(params.shape[0] // 2, -1, hidden_size)
params_right, params_left = params.chunk(2, dim=1)
del params
params = torch.cat([params_left, params_right], dim=1)
del params_left
del params_right
val_list.append(params)
params_to_sync = torch.cat(val_list, dim=0).contiguous()
else:
# w2_weight
while output_tensor_list:
params = output_tensor_list.pop(0)
params = params.reshape((local_num_experts, -1, hidden_size))
chunked_params = params.chunk(tp_size, dim=1)[tp_rank].contiguous()
del params
val_list.append(chunked_params)
params_to_sync = torch.cat(val_list, dim=0).transpose(1, 2).contiguous()
del val_list
return params_to_sync, True
else:
return params_to_sync, False
def allgather_routed_experts(self, name, params_to_sync, group_name, tp_rank): # pylint: disable=unused-argument
megatron_version = get_megatron_version()
if megatron_version == MegatronVersion.V4:
return self.allgather_routed_experts_from_hep(name, params_to_sync, group_name, tp_rank)
else:
raise NotImplementedError(
"ChatLearn does not support all-gathering routed experts for Megatron-LM, but supports QWen with HEP enabled. "
"Please export `QWEN_VERSION` as `qwen_moe_v1`."
)
def alltoall_routed_experts_from_hep(self, name, params_to_sync, comm_group):
"""
This function is applicable for synchronizing parameters from QWen with HEP enabled
to vLLM. In HEP, routed experts are split into a total number of EP size * TP size.
Thus, the function will all-to-all across EP size * TP size routed experts.
"""
if self.sync_map._to_alltoall_routed_experts_dict is None:
return params_to_sync, False
to_alltoall_routed_experts_dict = self.sync_map._to_alltoall_routed_experts_dict
layer_re = to_alltoall_routed_experts_dict["layer_re"]
to_regroup_modules_list = to_alltoall_routed_experts_dict["modules"]
m = layer_re.match(name)
if m is None:
return params_to_sync, False
op_name = m.group(2)
if op_name in to_regroup_modules_list:
tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
ep_size = self.src_module_args.args_dict["moe_expert_model_parallel_size"]
hep_size = tp_size * ep_size
moe_num_experts = self.src_module_args.args_dict["moe_num_experts"]
local_num_experts = moe_num_experts // hep_size
hidden_size = self.src_module_args.args_dict["hidden_size"]
if "dense_h_to_4h" in op_name:
# w13_weight
                # regroup among different tp slices
param = params_to_sync.view((moe_num_experts, -1, hidden_size))
param = param.reshape((local_num_experts * 2, -1, hidden_size))
params = list(param.chunk(hep_size, dim=1))
# reorder w1 and w3
params_list = []
while params:
param = params.pop(0)
param = param.reshape(param.shape[0] // 2, -1, hidden_size)
param_right, param_left = param.chunk(2, dim=1)
del param
param = torch.cat([param_left, param_right], dim=1)
del param_left
del param_right
params_list.append(param)
del params_to_sync
output = [
torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
for i in range(hep_size)
]
torch.distributed.all_to_all(output, params_list, group=comm_group)
del params_list
params_to_sync = torch.cat(output, dim=0).contiguous()
del output
else:
# w2_weight
param = params_to_sync.view((local_num_experts, -1, hidden_size))
params = list(param.chunk(hep_size, dim=1))
params_list = [ele.contiguous() for ele in params]
del param
del params
del params_to_sync
output = [
torch.empty(size=params_list[i].shape, dtype=params_list[i].dtype, device=params_list[i].device)
for i in range(hep_size)
]
torch.distributed.all_to_all(output, params_list, group=comm_group)
del params_list
params_to_sync = torch.cat(output, dim=0).transpose(1, 2).contiguous()
del output
return params_to_sync, True
else:
return params_to_sync, False
def alltoall_routed_experts(self, name, params_to_sync, comm_group): # pylint: disable=unused-argument
megatron_version = get_megatron_version()
if megatron_version == MegatronVersion.V4:
return self.alltoall_routed_experts_from_hep(name, params_to_sync, comm_group)
else:
raise NotImplementedError(
"ChatLearn does not support all-to-all routed experts for Megatron-LM, but supports QWen with HEP enabled. "
"Please export `QWEN_VERSION` as `qwen_moe_v1`."
)
def transform_parameters(self, params_to_sync_list):
"""
        Transform parameters, e.g. concatenation and ordering fixes.
"""
params_to_sync_list = self.concat_params(params_to_sync_list)
params_to_sync_list = self.fix_act_ordering(params_to_sync_list)
params_to_sync_list = self.fix_qkv_ordering(params_to_sync_list)
params_to_sync_list = self.fix_shared_expert_ordering(params_to_sync_list)
return params_to_sync_list
def regroup_qkv_tp_slices(self, name, param_data, tp_division):
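        """
        Split a fused QKV parameter into ``tp_division`` slices, with the q/k/v
        heads of each slice regrouped contiguously, so it can be scattered to
        the destination tensor-parallel ranks; parameters that are not fused
        QKV are returned unchanged.
        """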
param_data_shape = param_data.shape
        # Regroup qkv tensors into different tp slices; only needed for the inference model when the vLLM backend is enabled.
to_fix_qkv_ordering_dict = self.sync_map.to_fix_qkv_ordering_dict
# pylint: disable=too-many-nested-blocks
if "attention.query_key_value" in name or \
"self_attention.query_key_value" in name or \
"self_attention.linear_qkv" in name:
src_tp_size = self.src_module_args.args_dict["tensor_model_parallel_size"]
dst_tp_size = self.dst_module_args.args_dict["tensor_model_parallel_size"]
heads = self.src_module_args.args_dict["num_attention_heads"] // src_tp_size
hidden_size_per_head = self.src_module_args.args_dict["hidden_size"] // self.src_module_args.args_dict["num_attention_heads"]
param_shape = (3, heads, hidden_size_per_head) + param_data_shape[1:]
division = reduce(operator.mul, param_shape, 1)
num_elements = param_data.numel()
if num_elements == division:
if to_fix_qkv_ordering_dict is not None:
param_data = param_data.view(param_shape)
param_data_list = []
head_offset = heads // tp_division
for idx in range(tp_division):
start = idx * head_offset
end = start + head_offset
param_data_list.append(param_data[:,start:end])
param_data = torch.concat(param_data_list, dim=0).view(param_data_shape)
del param_data_list
else:
if self.src_module_args.args_dict["group_query_attention"]:
num_query_groups = self.src_module_args.args_dict["num_query_groups"]
                    assert num_query_groups == self.dst_module_args.args_dict["num_query_groups"], (
                        f"num_query_groups of src model ({num_query_groups}) must be equal to num_query_groups of "
                        f"dst model ({self.dst_module_args.args_dict['num_query_groups']}). Please double-check your config."
                    )
src_num_query_groups_per_replica = num_query_groups // src_tp_size
if dst_tp_size >= num_query_groups:
num_dst_kv_head_replicas = dst_tp_size // num_query_groups
else:
num_dst_kv_head_replicas = 1
else:
src_num_query_groups_per_replica = heads
num_dst_kv_head_replicas = 1
if to_fix_qkv_ordering_dict is not None or src_num_query_groups_per_replica == 1:
if len(param_data_shape) == 1:
param_data = param_data.view((heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head))
else:
param_data = param_data.view(
(heads + 2 * src_num_query_groups_per_replica, hidden_size_per_head, self.src_module_args.args_dict["hidden_size"]))
param_data_list = []
head_offset = heads // tp_division
for idx in range(tp_division):
q_start = idx * head_offset
q_end = q_start + head_offset
if num_dst_kv_head_replicas == 1:
if src_num_query_groups_per_replica > tp_division:
assert src_num_query_groups_per_replica % tp_division == 0, (
f"num_query_groups per replica of src model ({src_num_query_groups_per_replica}) "
f"must be divisible by tp_division ({tp_division}). Please double-check your config."
)
kv_offset = src_num_query_groups_per_replica // tp_division
else:
kv_offset = 1
k_start = (heads + idx) if src_num_query_groups_per_replica // tp_division else heads
k_end = k_start + kv_offset
v_start = k_start + src_num_query_groups_per_replica
v_end = v_start + kv_offset
else:
k_start = heads + idx // num_dst_kv_head_replicas
k_end = k_start + 1
v_start = k_start + src_num_query_groups_per_replica
v_end = v_start + 1
q_proj = param_data[q_start:q_end].contiguous()
k_proj = param_data[k_start:k_end].contiguous()
v_proj = param_data[v_start:v_end].contiguous()
qkv_proj = torch.cat([q_proj, k_proj, v_proj], dim=0)
if len(param_data_shape) == 1:
qkv_proj = qkv_proj.reshape(-1).contiguous()
else:
qkv_proj = qkv_proj.reshape(-1, self.src_module_args.args_dict["hidden_size"]).contiguous()
param_data_list.append(qkv_proj)
param_data = torch.concat(param_data_list, dim=0)
del param_data_list
return param_data
def regroup_params_to_sync(self, name, param_data, tp_division, regroup_routed_experts=False):
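        """
        Regroup ``param_data`` for the destination tensor-parallel layout:
        first regroup fused QKV slices, then delegate to the base
        implementation for any remaining regrouping.
        """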
param_data = self.regroup_qkv_tp_slices(name, param_data, tp_division)
        return super().regroup_params_to_sync(name, param_data, tp_division, regroup_routed_experts)


class MegatronVllmQWenSync(MegatronVllmSync):
"""qwen"""
def map_src_to_dst(self, src_names, src_pipe_layer_offset):
"""
:meta private:
"""
self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
        return Megatron2QWenSyncMap(src_names, src_pipe_layer_offset, QwenVersion.v_1.value)


class MegatronVllmQWen2Sync(MegatronVllmSync):
"""qwen2"""
def map_src_to_dst(self, src_names, src_pipe_layer_offset):
self._to_fix_qkv_ordering_func = split_attn_state
        return Megatron2QWenSyncMap(src_names, src_pipe_layer_offset, QwenVersion.v_2.value)


class MegatronVllmLlamaSync(MegatronVllmSync):
"""llama"""
def map_src_to_dst(self, src_names, src_pipe_layer_offset):
use_legacy_models = get_use_legacy_models(self.src_model.module_args.args_dict)
sync_map_cls = Megatron2LlamaSyncMap if use_legacy_models else MCore2LlamaSyncMap
self._to_fix_qkv_ordering_func = fix_qwen_query_key_value_ordering
return sync_map_cls(src_names, src_pipe_layer_offset)
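

# Minimal usage sketch (illustrative only; `policy_model` and `inference_model`
# are assumed to be ChatLearn model wrappers exposing `module_args`, and the
# actor / parameter-list arguments come from the parameter-sync runtime):
#
#   sync = MegatronVllmLlamaSync(policy_model, inference_model)
#   src_names, dst_names = sync.map_name_from_src_to_dst(
#       send_actor, recv_actor, src_names, dst_names)
#   params_to_sync_list = sync.transform_parameters(params_to_sync_list)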