chatlearn/synchronizer/base.py
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""base"""
import torch
from chatlearn.utils import future
from chatlearn.utils import utils
class BaseSync:
"""Base synchronizer"""
def __init__(self, src_model, dst_model):
self.src_model = src_model
self.dst_model = dst_model
self.is_parameter_changed = False
self.concat_params_dict = None
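        # ``concat_params_dict`` may be set later by the caller; when not None it also
        # enables the ``"mlp.w1"`` branch in ``regroup_params_to_sync``.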

    def get_or_cache(self, actor, func_name, *args, **kwargs):
        """Call ``func_name`` on the remote ``actor``, caching the result for later calls."""
def inner_func(*args, **kwargs):
return future.get(getattr(getattr(actor, func_name), 'remote')(*args, **kwargs))
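        # Cache remote results on this synchronizer in a per-(actor, func_name) dict so
        # repeated queries avoid extra remote round trips.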
cached_name = str(actor) + "_" + func_name
if hasattr(self, cached_name):
cached = getattr(self, cached_name)
else:
cached = {}
setattr(self, cached_name, cached)
return utils.get_or_cache(cached, actor, inner_func, *args, **kwargs)

    def map_name_from_src_to_dst(self, send_actor, recv_actor, src_names, dst_names): # pylint: disable=unused-argument
        """
        Map layer names from the src model to the dst model.
        """
        return src_names, dst_names

    def allgather_routed_experts(self, name, params_to_sync, group_name, tp_rank): # pylint: disable=unused-argument
        """
        All-gather routed expert params. The base implementation is a no-op that
        returns the params unchanged.
        """
        return params_to_sync, False

    def alltoall_routed_experts(self, name, params_to_sync, comm_group): # pylint: disable=unused-argument
        """
        All-to-all routed expert params. The base implementation is a no-op that
        returns the params unchanged.
        """
        return params_to_sync, False

    def transform_parameters(self, params_to_sync_list):
        """
        Transform parameters before syncing, e.g. concatenate tensors or fix their ordering.
        The base implementation returns the list unchanged.
        """
        return params_to_sync_list

    def regroup_params_to_sync(self, name, param_data, tp_division, regroup_routed_experts=False):
        """
        Re-pack ``param_data`` into ``tp_division`` contiguous tensor-parallel slices
        (kept in the original shape) so that a later split along dim 0 recovers each
        destination rank's slice.

        :meta private:
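
        Example (illustrative: a tiny ``[2, 4]`` tensor stands in for a real weight)::

            >>> sync = BaseSync(None, None)
            >>> w = torch.arange(8.).reshape(2, 4)  # [w, h * tp_size] layout, tp_division=2
            >>> sync.regroup_params_to_sync("self_attention.dense", w, tp_division=2)
            tensor([[0., 1., 4., 5.],
                    [2., 3., 6., 7.]])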
"""
param_data_shape = param_data.shape
# Regroup these tensors into different tp slices.
# Output: [tp_slice_0, tp_slice_1, ...]
# Comment:
# src -> dst: [w, h * tp_size] -> tp_size * [w, h]
# 'self_attention.dense' in QWen and LLama2 legacy
# 'mlp.dense_4h_to_h' in QWen and LLama2 legacy model
# 'mlp.linear_fc2' in LLama2 mcore model
        # 'mlp.shared_experts.dense_4h_to_h' in QWen-MoE model
# src -> dst: [w * tp_size, h] -> tp_size * [w, h]
# 'mlp.dense_h_to_4h' in QWen and LLama2 legacy
# 'mlp.linear_fc1' in LLama2 mcore model
# 'mlp.w1' in QWen model only for vLLM backend
        # 'mlp.shared_experts.dense_h_to_4h' in QWen-MoE model
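        # For example, with tp_division=2 a [w, 2h] weight laid out as [slice_0 | slice_1]
        # along dim 1 is re-packed (same shape) so that a later split along dim 0
        # recovers each [w, h] slice.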
if (
"self_attention.dense" in name
or "mlp.dense_4h_to_h" in name
or "mlp.linear_fc2" in name
or "mlp.shared_experts.dense_4h_to_h" in name
):
param_data_list = []
col_offset = param_data_shape[1] // tp_division
for idx in range(tp_division):
start = idx * col_offset
end = start + col_offset
                param_data_list.append(param_data[:, start:end])
param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
del param_data_list
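        # These fused h_to_4h / fc1 weights stack two projections (typically the gate w1
        # and up w2) along dim 0, so each tp slice below pairs its matching w1 and w2
        # row blocks instead of taking one contiguous row range.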
if (
"mlp.dense_h_to_4h" in name
or "mlp.linear_fc1" in name
or ("mlp.w1" in name and self.concat_params_dict is not None)
or "mlp.shared_experts.dense_h_to_4h" in name
):
param_data_list = []
row_offset = param_data_shape[0] // tp_division // 2
for idx in range(tp_division):
w1_start = idx * row_offset
w1_end = w1_start + row_offset
w2_start = (idx + tp_division) * row_offset
w2_end = w2_start + row_offset
param_data_list.append(
torch.concat([param_data[w1_start:w1_end,:], param_data[w2_start:w2_end,:]], dim=0))
param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
del param_data_list
# src -> dst: src_tp_size * [e, h, w] -> dst_tp_size * [e, h, w // tp_division]
# 'mlp.experts.dense_4h_to_h' in QWen-MoE model when training with QWen+vLLM
# src -> dst: src_tp_size * [e, w, h] -> dst_tp_size * [e, w // tp_division, h]
        # 'mlp.experts.dense_h_to_4h' in QWen-MoE model when training with QWen+vLLM
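        # Routed expert weights keep their leading expert dim; only a trailing hidden
        # dim is sliced per tp rank, and the slices are re-stacked along dim 0 before
        # restoring the original shape.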
if regroup_routed_experts:
if "mlp.experts.dense_4h_to_h" in name:
param_data_list = []
height_offset = param_data_shape[2] // tp_division
for height_idx in range(tp_division):
height_start = height_idx * height_offset
height_end = height_start + height_offset
param_data_list.append(param_data[:, :, height_start:height_end])
param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
del param_data_list
elif "mlp.experts.dense_h_to_4h" in name:
param_data_list = []
param_data = param_data.reshape(param_data_shape[0] * 2, -1, param_data_shape[2])
col_offset = param_data.shape[1] // tp_division
for idx in range(tp_division):
start = idx * col_offset
end = start + col_offset
param_data_list.append(param_data[:, start:end, :])
param_data = torch.concat(param_data_list, dim=0).contiguous().view(param_data_shape)
del param_data_list
return param_data