chatlearn/synchronizer/base.py (81 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base synchronizer used when syncing parameters from a src model to a dst model."""

import torch

from chatlearn.utils import future
from chatlearn.utils import utils


class BaseSync:
    """Base synchronizer.

    Most methods are identity hooks that return their inputs unchanged (the
    unused arguments are kept in the signatures, presumably so subclasses can
    override them). The two methods with real logic are ``get_or_cache`` and
    ``regroup_params_to_sync``.

    Attributes set by ``__init__``:
        src_model: stored as passed in.
        dst_model: stored as passed in.
        is_parameter_changed: initialised to ``False``.
        concat_params_dict: initialised to ``None``; while it is ``None``,
            ``regroup_params_to_sync`` does not treat ``"mlp.w1"`` names as
            row-split weights.
    """

    def __init__(self, src_model, dst_model):
        self.src_model = src_model
        self.dst_model = dst_model
        self.is_parameter_changed = False
        self.concat_params_dict = None

    def get_or_cache(self, actor, func_name, *args, **kwargs):
        """Call ``actor.<func_name>.remote(...)`` through a per-(actor, method) cache.

        A dict is stored on this instance under the attribute name
        ``str(actor) + "_" + func_name`` (created on first use) and handed to
        ``utils.get_or_cache`` together with ``actor`` and a callable that
        performs the remote call and resolves it with ``future.get``.
        NOTE(review): presumably ``utils.get_or_cache`` only invokes the
        callable when ``actor`` is not yet in the dict -- confirm in
        ``chatlearn.utils.utils``.

        Args:
            actor: object whose ``func_name`` attribute exposes ``.remote``.
            func_name: name of the method to call on ``actor``.
            *args: forwarded to ``utils.get_or_cache`` after the callable.
            **kwargs: forwarded to ``utils.get_or_cache``.

        Returns:
            Whatever ``utils.get_or_cache`` returns.
        """
        def inner_func(*call_args, **call_kwargs):
            remote_method = getattr(actor, func_name).remote
            return future.get(remote_method(*call_args, **call_kwargs))

        cache_attr = "_".join((str(actor), func_name))
        try:
            cache = getattr(self, cache_attr)
        except AttributeError:
            # First request for this (actor, method) pair: create its cache.
            cache = {}
            setattr(self, cache_attr, cache)
        return utils.get_or_cache(cache, actor, inner_func, *args, **kwargs)

    def map_name_from_src_to_dst(self, send_actor, recv_actor, src_names, dst_names):  # pylint: disable=unused-argument
        """Map layer names from the src model to the dst model.

        The base implementation performs no mapping.

        Returns:
            The tuple ``(src_names, dst_names)``, unchanged.
        """
        return src_names, dst_names

    def allgather_routed_experts(self, name, params_to_sync, group_name, tp_rank):  # pylint: disable=unused-argument
        """Allgather routed expert params.

        The base implementation performs no communication.

        Returns:
            The tuple ``(params_to_sync, False)``.
        """
        return params_to_sync, False

    def alltoall_routed_experts(self, name, params_to_sync, comm_group):  # pylint: disable=unused-argument
        """Alltoall routed expert params.

        The base implementation performs no communication.

        Returns:
            The tuple ``(params_to_sync, False)``.
        """
        return params_to_sync, False

    def transform_parameters(self, params_to_sync_list):
        """Transform parameters, e.g. concat or fix ordering.

        The base implementation returns ``params_to_sync_list`` unchanged.
        """
        return params_to_sync_list

    def regroup_params_to_sync(self, name, param_data, tp_division, regroup_routed_experts=False):
        """Reorder ``param_data`` so that each tp slice is laid out contiguously.

        Which rule applies is decided by substring matches on ``name``. For a
        matching rule the tensor is cut into ``tp_division`` pieces, the
        pieces are concatenated along dim 0, and the result is viewed with the
        shape ``param_data`` had on entry. If no rule matches, ``param_data``
        is returned as is.

        Args:
            name: parameter name, matched against known layer-name fragments.
            param_data: tensor to regroup.
            tp_division: number of slices to cut the tensor into.
            regroup_routed_experts: when true, also apply the rules for
                ``mlp.experts.*`` weights.

        Returns:
            The regrouped tensor, with the same shape as the input.

        :meta private:
        """
        original_shape = param_data.shape
        slice_ids = range(tp_division)

        def restack(pieces):
            # Stack the slices along dim 0 and reinterpret the contiguous
            # buffer with the shape the parameter had on entry.
            return torch.concat(pieces, dim=0).contiguous().view(original_shape)

        # Regroup these tensors into different tp slices.
        # Output: [tp_slice_0, tp_slice_1, ...]

        # src -> dst: [w, h * tp_size] -> tp_size * [w, h]
        #   'self_attention.dense' in QWen and LLama2 legacy
        #   'mlp.dense_4h_to_h' in QWen and LLama2 legacy model
        #   'mlp.linear_fc2' in LLama2 mcore model
        #   'mlp.shared_experts.dense_4h_to_h' in QWen-MoE model
        column_split_keys = (
            "self_attention.dense",
            "mlp.dense_4h_to_h",
            "mlp.linear_fc2",
            "mlp.shared_experts.dense_4h_to_h",
        )
        if any(key in name for key in column_split_keys):
            width = original_shape[1] // tp_division
            param_data = restack(
                [param_data[:, i * width:(i + 1) * width] for i in slice_ids]
            )

        # src -> dst: [w * tp_size, h] -> tp_size * [w, h]
        #   'mlp.dense_h_to_4h' in QWen and LLama2 legacy
        #   'mlp.linear_fc1' in LLama2 mcore model
        #   'mlp.w1' in QWen model only for vLLM backend
        #   'mlp.shared_experts.dense_h_to_4h' in QWen-MoE model
        is_row_split = (
            any(key in name for key in ("mlp.dense_h_to_4h", "mlp.linear_fc1"))
            or ("mlp.w1" in name and self.concat_params_dict is not None)
            or "mlp.shared_experts.dense_h_to_4h" in name
        )
        if is_row_split:
            # The rows are handled as two stacked halves; slice i takes its
            # block from the first half and the matching block from the
            # second half.
            block = original_shape[0] // tp_division // 2
            param_data = restack([
                torch.concat(
                    [
                        param_data[i * block:(i + 1) * block, :],
                        param_data[(i + tp_division) * block:(i + tp_division + 1) * block, :],
                    ],
                    dim=0,
                )
                for i in slice_ids
            ])

        if not regroup_routed_experts:
            return param_data

        # src -> dst: src_tp_size * [e, h, w] -> dst_tp_size * [e, h, w // tp_division]
        #   'mlp.experts.dense_4h_to_h' in QWen-MoE model when training with QWen+vLLM
        # src -> dst: src_tp_size * [e, w, h] -> dst_tp_size * [e, w // tp_division, h]
        #   'mlp.experts.dense_h_to_4h' in QWen-MoE model when training with QWen+vLLM
        if "mlp.experts.dense_4h_to_h" in name:
            depth = original_shape[2] // tp_division
            param_data = restack(
                [param_data[:, :, i * depth:(i + 1) * depth] for i in slice_ids]
            )
        elif "mlp.experts.dense_h_to_4h" in name:
            # Double the leading dim before slicing along dim 1.
            param_data = param_data.reshape(original_shape[0] * 2, -1, original_shape[2])
            width = param_data.shape[1] // tp_division
            param_data = restack(
                [param_data[:, i * width:(i + 1) * width, :] for i in slice_ids]
            )
        return param_data