chatlearn/synchronizer/parameter_sync.py (1,612 lines of code) (raw):

# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Sync parameters""" import concurrent.futures import traceback import time from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from itertools import cycle from typing import List, Dict from queue import PriorityQueue import torch from tqdm import tqdm from chatlearn.launcher.initialize import patch_ray from chatlearn.utils import future from chatlearn.utils import utils from chatlearn.utils.constant import LORA_WEIGHT_PREFIX from chatlearn.utils.constant import PARAM_SYNC_COMM_TYPE from chatlearn.utils.constant import ROUTED_EXPERT_REGROUPING_COMM_TYPE from chatlearn.utils.global_vars import get_args from chatlearn.utils.logger import logger from chatlearn.utils.utils import execute_in_parallel from chatlearn.utils.timer import Timers from chatlearn.synchronizer.scheduler import CollectiveTask, parallel_execute_collective_tasks from . 
import get_synchronizer patch_ray() class ParameterSyncGroup: """ParameterSyncGroup""" def __init__(self, src_model, dst_model, group_name, frequency, error_signal): self.src_model = src_model self.dst_model = dst_model self.synchronizer = get_synchronizer(src_model, dst_model) self.group_name = group_name self.error_signal = error_signal self.send_recv_actor_mappings = defaultdict(list) self.recv_send_actor_mappings = defaultdict(list) self.send_recv_actor_mappings_stage2 = defaultdict(list) self.recv_send_actor_mappings_stage2 = defaultdict(list) self.actor2rank = {} self.actor2model = {} self._debug = get_args().runtime_args.debug self._num_src_pipeline_stage = None self._num_dst_pipeline_stage = None self._num_src_expert_parallel = None self._num_dst_expert_parallel = None self._num_src_tensor_parallel = None self._num_dst_tensor_parallel = None self._send_recv_param_names = {} self._actor2pipe = {} self._actor2tp = {} self._actor2ep = {} self._actor2dp = {} self._comm_type = get_args().runtime_args.param_sync_comm_type if src_model.colocate_with(dst_model) and self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: if self.num_src_tensor_parallel % 2 == 1 and self.num_dst_tensor_parallel % 2 == 1: logger.warning("Only support PARAM_SYNC_COMM_TYPE.BROADCAST when TP SIZE is even number, use P2P instead") self._comm_type = PARAM_SYNC_COMM_TYPE.P2P self.concurrent_comm = get_args().runtime_args.concurrent_comm self._enable_lora = self.src_model.module_args.lora.enable_lora # sync every n episodes, n = 0 for no param sync self._frequency = frequency self._free_sync_collective_group = get_args().runtime_args.free_sync_collective_group self._is_collective_group_created = True self.collective_groups = [] self.groups2actors = {} # group_name -> []actors self.src_dp_size = future.get(self.src_model.replicas[0].all_actors[0].get_data_parallel_size.remote()) self.send_actors_to_regroup_routed_experts = None self._comm_type_to_regroup_routed_experts = 
get_args().runtime_args.routed_expert_regrouping_comm_type assert self._comm_type_to_regroup_routed_experts in \ [ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER, ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL], \ f"Only support 'allgather' or 'alltoall' for routed expert regrouping, while {self._comm_type_to_regroup_routed_experts}" if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL: if self.num_dst_tensor_parallel * self.num_dst_expert_parallel != self.num_src_tensor_parallel * self.num_src_expert_parallel: logger.info("Only support ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL when src tp eqs dst tp, use 'allgather' instead.") self._comm_type_to_regroup_routed_experts = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER logger.info(f"Set ROUTED_EXPERT_REGROUPING_COMM_TYPE = {self._comm_type_to_regroup_routed_experts}.") self.sorted_send_actors = None self.sorted_send_actors_stage2 = None self.actor2synchronizer = {} self.setup_collective_group() self.setup_rank_mapping() self.timers = Timers() def get_group_name(self, actors): return f"{self.group_name}_" + "_".join(str(self.actor2rank[actor]) for actor in actors) @property def frequency(self): return self._frequency def get_or_cache(self, actor, func_name, *args, **kwargs): def inner_func(*args, **kwargs): return future.get(getattr(getattr(actor, func_name), 'remote')(*args, **kwargs)) cached_name = str(actor) + "_" + func_name if hasattr(self, cached_name): cached = getattr(self, cached_name) else: cached = {} setattr(self, cached_name, cached) return utils.get_or_cache(cached, actor, inner_func, *args, **kwargs) def is_same_gpu(self, src_actor, dst_actor): src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") src_address = self.get_or_cache(src_actor, "get_address") dst_address = self.get_or_cache(dst_actor, "get_address") return src_gpu == dst_gpu and src_address == dst_address @property def 
num_src_pipeline_stage(self): if self._num_src_pipeline_stage is None: self._num_src_pipeline_stage = future.get(self.src_model.replicas[0].all_actors[0].pipeline_model_parallel_size.remote()) return self._num_src_pipeline_stage @property def num_dst_pipeline_stage(self): if self._num_dst_pipeline_stage is None: self._num_dst_pipeline_stage = future.get(self.dst_model.replicas[0].all_actors[0].pipeline_model_parallel_size.remote()) return self._num_dst_pipeline_stage @property def num_src_tensor_parallel(self): if self._num_src_tensor_parallel is None: self._num_src_tensor_parallel = future.get(self.src_model.replicas[0].all_actors[0].tensor_model_parallel_size.remote()) return self._num_src_tensor_parallel @property def num_dst_tensor_parallel(self): if self._num_dst_tensor_parallel is None: self._num_dst_tensor_parallel = future.get(self.dst_model.replicas[0].all_actors[0].tensor_model_parallel_size.remote()) return self._num_dst_tensor_parallel @property def num_src_expert_parallel(self): if self._num_src_expert_parallel is None: self._num_src_expert_parallel = future.get(self.src_model.replicas[0].all_actors[0].expert_model_parallel_size.remote()) return self._num_src_expert_parallel @property def num_dst_expert_parallel(self): if self._num_dst_expert_parallel is None: self._num_dst_expert_parallel = future.get(self.dst_model.replicas[0].all_actors[0].expert_model_parallel_size.remote()) return self._num_dst_expert_parallel def setup_collective_group(self): """ Set up collective group for parameter sync. The group consists of all actors from both model's (src & dst) replicas. For P2P Sync, create the group directly. Assigning each actor an uniq rank id. 
For Broadcast, assigning the rank id only, and defer the group creation during the first parameter sync happens TODO-1: Maybe can remove the P2P sync method in the future for code clarity TODO-2: the _setup_ranks is a public method of DistActor class but named with a leading _ which makes it looks like a private method """ refs = [] # we put src_model first, so we don't need to change the rank of training model models = [self.src_model, self.dst_model] world_size = sum(model.actor_num for model in models) rank_offset = 0 for model in models: for replica in model.replicas: if self._comm_type == PARAM_SYNC_COMM_TYPE.P2P: refs += replica._setup_collective_group(rank_offset, world_size, self.group_name) else: replica._setup_ranks(rank_offset) rank_offset += replica.actor_num if refs: future.get(refs) logger.info(f"init collective group done for {self.group_name}") def destroy_collective_group(self): refs = [] try: refs.extend(self.src_model.destroy_collective_group()) refs.extend(self.dst_model.destroy_collective_group()) future.wait(refs) logger.info(f"destroy_collective_group success for {self.group_name}") except Exception as e: logger.exception(f"destroy_collective_group fail for {self.group_name} {e}") def setup_rank_mapping(self): self.tp_num_mapping = self.num_dst_tensor_parallel // self.num_src_tensor_parallel if self.tp_num_mapping == 1: self.build_rank_mapping() else: self.build_rank_mapping_two_stage() def insert_actor2rank(self, actor, rank: int): if actor not in self.actor2rank: self.actor2rank[actor] = rank def insert_actor2model(self, actor, model): if actor not in self.actor2model: self.actor2model[actor] = model def add_routed_experts_regrouping_actor(self, model, ranks_group: List): for replica_ranks_group in ranks_group: if isinstance(replica_ranks_group[0], list): for tp_ranks in replica_ranks_group: for rank in tp_ranks: actor = model.get_actor(rank) self.insert_actor2rank(actor, rank) self.insert_actor2model(actor, model) else: for rank in 
replica_ranks_group: actor = model.get_actor(rank) self.insert_actor2rank(actor, rank) self.insert_actor2model(actor, model) # pylint: disable=unused-argument def empty_add_recv_actor(self, src_rank, dst_rank): return def warmup_groups(self): return def add_recv_actor(self, src_rank, dst_rank): src_actor = self.src_model.get_actor(src_rank) self.insert_actor2rank(src_actor, src_rank) self.insert_actor2model(src_actor, self.src_model) dst_actor = self.dst_model.get_actor(dst_rank) self.insert_actor2rank(dst_actor, dst_rank) self.insert_actor2model(dst_actor, self.dst_model) src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") src_tp_rank = self.get_actor_tp_rank(src_actor) dst_tp_rank = self.get_actor_tp_rank(dst_actor) src_pp_rank = self.get_actor_pipe_rank(src_actor) dst_pp_rank = self.get_actor_pipe_rank(dst_actor) src_ep_rank = self.get_actor_ep_rank(src_actor) dst_ep_rank = self.get_actor_ep_rank(dst_actor) logger.debug(f"build rank mapping from {src_rank} to {dst_rank}, from gpu {src_gpu} to {dst_gpu}, " + f"from pipe_stage {src_pp_rank} to {dst_pp_rank}, " + f"from tp rank {src_tp_rank} to {dst_tp_rank}, " + f"from ep rank {src_ep_rank} to {dst_ep_rank}.") self.send_recv_actor_mappings[src_actor].append(dst_actor) self.recv_send_actor_mappings[dst_actor].append(src_actor) def add_recv_actor_stage2(self, src_rank, dst_rank): src_actor = self.dst_model.get_actor(src_rank) self.insert_actor2rank(src_actor, src_rank) self.insert_actor2model(src_actor, self.dst_model) # stage 2 sends from dst_model to dst_model dst_actor = self.dst_model.get_actor(dst_rank) self.insert_actor2rank(dst_actor, dst_rank) self.insert_actor2model(dst_actor, self.dst_model) src_gpu = future.get(src_actor.get_visible_gpus.remote()) dst_gpu = future.get(dst_actor.get_visible_gpus.remote()) # TODO(jiangle.jl): support ep/cp. 
src_tp_rank = self.get_actor_tp_rank(src_actor) dst_tp_rank = self.get_actor_tp_rank(dst_actor) src_pp_rank = self.get_actor_pipe_rank(src_actor) dst_pp_rank = self.get_actor_pipe_rank(dst_actor) logger.debug(f"build rank mapping from {src_rank} to {dst_rank}, from gpu {src_gpu} to {dst_gpu}, " + \ f"from pipe_stage {src_pp_rank} to {dst_pp_rank}, " + \ f"from tp rank {src_tp_rank} to {dst_tp_rank}") self.send_recv_actor_mappings_stage2[src_actor].append(dst_actor) self.recv_send_actor_mappings_stage2[dst_actor].append(src_actor) def set_send_actors_to_regroup_routed_experts(self, src_replica_ranks_group): if self.send_actors_to_regroup_routed_experts is None: self.send_actors_to_regroup_routed_experts = [] for src_replica_ranks in src_replica_ranks_group: self.send_actors_to_regroup_routed_experts.append([]) if isinstance(src_replica_ranks[0], list): for src_tp_ranks in src_replica_ranks: self.send_actors_to_regroup_routed_experts[-1].extend( [self.src_model.get_actor(src_rank) for src_rank in src_tp_ranks]) else: self.send_actors_to_regroup_routed_experts[-1].extend( [self.src_model.get_actor(src_rank) for src_rank in src_replica_ranks]) def get_src_and_dst_dp_ranks(self, is_except_routed_experts=False): """ Return: The DP Group List for src & dst model [[DP-0], [DP-1] ... [DP-N]] """ dst_dp_ranks = self.dst_model.all_ranks local_src_ranks = future.get(self.src_model.replicas[0].get_local_param_ranks()) if local_src_ranks[0] is None or dst_dp_ranks is None: if self._debug: logger.warning( f"DEBUG MODE! 
src_dp_ranks {local_src_ranks} or dst_dp_ranks: {dst_dp_ranks} is None, " "make sure they have values in real application.") return local_src_ranks, dst_dp_ranks else: raise Exception(f"src_dp_ranks {local_src_ranks} or dst_dp_ranks {dst_dp_ranks} should not be None") dp_rank_to_ranks = defaultdict(list) for local_ranks, dp_rank in local_src_ranks: dp_rank_to_ranks[dp_rank].append(local_ranks[dp_rank]) if is_except_routed_experts: # for weight except routed expert, ep_size using for data parallel. # TODO-1 The logic here is a little bit complicate, it would be better to move to a seperate function # TODO-2 The logic here is about HEP, would be better called from class ParameterSyncGroupwithHEP src_hep_size = self.num_src_expert_parallel * self.num_src_tensor_parallel new_dict = defaultdict(list) idx = 0 for dp_rank, values in dp_rank_to_ranks.items(): assert len(values) % src_hep_size == 0, ( f"len of values({len(values)}) for dp_rank {dp_rank} must be divisible by hep size({src_hep_size})" f" when call get_src_and_dst_dp_ranks_for_except_routed_experts." ) pp_blocks = [values[i:i + src_hep_size] for i in range(0, len(values), src_hep_size)] sub_blocks_per_pp = [] for block in pp_blocks: sub_block_size = src_hep_size // self.num_src_expert_parallel sub_blocks = [block[i:i + sub_block_size] for i in range(0, src_hep_size, sub_block_size)] sub_blocks_per_pp.append(sub_blocks) for i in range(self.num_src_expert_parallel): merged_group = [] for sub_blocks in sub_blocks_per_pp: merged_group.extend(sub_blocks[i]) new_dict[idx].extend(merged_group) idx += 1 src_dp_ranks = [i[1] for i in sorted(new_dict.items())] else: src_dp_ranks = [i[1] for i in sorted(dp_rank_to_ranks.items())] return src_dp_ranks, dst_dp_ranks def get_load_balance_dst_rank( self, lb_dst_offset_pq_dict, s_idx, start, src_rank, dst_replica_ranks_group, d_idx, pre_allocate=False ): """Get the dst_rank for load balance when gpu collides. 
""" dst_tp_indices = sorted([ s_idx * self.tp_num_mapping + (start + i) % self.tp_num_mapping for i in range(self.tp_num_mapping) ]) indexed_dst_tp_group = tuple(dst_replica_ranks_group[d_idx][dst_tp_index] for dst_tp_index in dst_tp_indices) # Construct a priority queue (PQ) to retrieve `dst_rank` for load balancing when gpu collides. # The key of the PQ is (hit_time, max_seq_num), meaning that the rank is used for `hit_time` times, # while `max_seq_num` further sorts `dst_rank` when `hit_time` remains the same. if indexed_dst_tp_group not in lb_dst_offset_pq_dict: pq = PriorityQueue() max_seq_num = 0 hit_time = 0 while max_seq_num < self.tp_num_mapping: pq.put(( hit_time, max_seq_num, s_idx * self.tp_num_mapping + (start + max_seq_num) % self.tp_num_mapping )) max_seq_num += 1 lb_dst_offset_pq_dict[indexed_dst_tp_group] = [pq, max_seq_num] else: max_seq_num = lb_dst_offset_pq_dict[indexed_dst_tp_group][1] # Each time, we retrieve the first value of the PQ. # 1. If the first `dst_rank` will encounter gpu collision with `src_rank`, we retrieve it, set `seq_num` to # `max_seq_num`, increase `max_seq_num` by 1, and finally insert <(hit_time, seq_num), offset> back # to the PQ. # 2. If the first `dst_rank` won't encounter gpu collision with `src_rank`, we insert it to another PQ (called # legal_lb_recv_offset_pq). After looping through all legal solutions, we will retrieve the first load-balance one. # 3. If we cannot find a legal solution after `self.tp_num_mapping` times, all `src_rank` will encounter gpu collision # with `dst_rank`, we throw a runtime exception. lb_recv_offset_pq = lb_dst_offset_pq_dict[indexed_dst_tp_group][0] legal_lb_recv_offset_pq = PriorityQueue() assert len(lb_recv_offset_pq.queue) == self.tp_num_mapping, ( "length of the load-balance recv_offset priority queue must be equal to tp_num_mapping, " f"got {len(lb_recv_offset_pq.queue)} and {self.tp_num_mapping}." 
) is_collide = False for _ in range(self.tp_num_mapping): hit_time, seq_num, offset = lb_recv_offset_pq.get() dst_rank = dst_replica_ranks_group[d_idx][offset] logger.debug(f"Trying to match {src_rank} and {dst_rank} (hit={hit_time}), remaining queue={lb_recv_offset_pq.queue})") src_actor = self.src_model.get_actor(src_rank) dst_actor = self.dst_model.get_actor(dst_rank) if self.is_same_gpu(src_actor, dst_actor): logger.info( f"src_rank ({src_rank}) will share the same gpu with dst_rank ({dst_rank}). " "This is not allowed in NCCL send-recv. ChatLearn will skip dst_rank to the next legal one." ) is_collide = True lb_recv_offset_pq.put((hit_time, max_seq_num, offset)) max_seq_num += 1 lb_dst_offset_pq_dict[indexed_dst_tp_group][1] = max_seq_num else: legal_lb_recv_offset_pq.put((hit_time, seq_num, offset)) logger.debug(f"legal_lb_recv_offset_pq={legal_lb_recv_offset_pq.queue}") # if pre_allocate is True and no collide, we directly return if pre_allocate is True and is_collide is False: while len(legal_lb_recv_offset_pq.queue) > 0: lb_recv_offset_pq.put(legal_lb_recv_offset_pq.get()) return None, False # there must be at least one legal recv offset if len(legal_lb_recv_offset_pq.queue) == 0: raise RuntimeError( f"Rank mapping solution is infeasible because src_rank ({src_rank}) will collide with all candidates." 
) # extract the first legal one to keep load balance hit_time, seq_num, offset = legal_lb_recv_offset_pq.get() lb_recv_offset_pq.put((hit_time + 1, seq_num, offset)) # put other solutions back to lb_recv_offset_pq while len(legal_lb_recv_offset_pq.queue) > 0: lb_recv_offset_pq.put(legal_lb_recv_offset_pq.get()) logger.debug(f"after retrieving, lb_recv_offset_pq = {lb_recv_offset_pq.queue}") # return dst_rank dst_rank = dst_replica_ranks_group[d_idx][offset] return dst_rank, is_collide def build_rank_mapping(self, add_recv_actor_fn=None): """ setup rank mapping for src parameter and dst parameter get rank for one src_model, without model replicas for each DP Group: for each TP & EP Group: for each PP Group: mapping[src] = dst """ if add_recv_actor_fn is None: add_recv_actor_fn = self.add_recv_actor src_dp_ranks, dst_dp_ranks = self.get_src_and_dst_dp_ranks() if self._debug and (src_dp_ranks[0] is None or dst_dp_ranks is None): return if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: replica_rank_iter = cycle(reversed(src_dp_ranks)) else: replica_rank_iter = cycle(iter(src_dp_ranks)) logger.debug(f"src_dp_ranks: {src_dp_ranks}") logger.debug(f"dst_dp_ranks: {dst_dp_ranks}") assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0 def split_ranks_by_tp_and_ep_size(ranks, tp_size : int = 1, ep_size : int = 1): tp_and_ep_size = tp_size * ep_size return [ranks[i:i + tp_and_ep_size] for i in range(0, len(ranks), tp_and_ep_size)] for dst_replica_ranks in dst_dp_ranks: src_replica_ranks = next(replica_rank_iter) src_replica_ranks_group = split_ranks_by_tp_and_ep_size(src_replica_ranks, self.num_src_tensor_parallel, self.num_src_expert_parallel) dst_replica_ranks_group = split_ranks_by_tp_and_ep_size(dst_replica_ranks, self.num_dst_tensor_parallel, self.num_dst_expert_parallel) self.set_send_actors_to_regroup_routed_experts(src_replica_ranks_group) pipe_map_interval = self.num_src_pipeline_stage // 
self.num_dst_pipeline_stage for i, src_tp_group in enumerate(src_replica_ranks_group): j = i // pipe_map_interval for src_rank, dst_rank in zip(src_tp_group, dst_replica_ranks_group[j]): add_recv_actor_fn(src_rank, dst_rank) # pylint: disable=unused-argument def build_rank_mapping_for_ep(self, add_recv_actor_fn=None): # Currently, we do not support build rank mapping for expert parallelism raise NotImplementedError("ChatLearn does not support build rank mapping from Megatron-LM for expert parallelism") def build_rank_mapping_two_stage(self, add_recv_actor_fn=None): # setup rank mapping for src parameter and dst parameter # get rank for one src_model, without model replicas if add_recv_actor_fn is None: add_recv_actor_stage1_fn = self.add_recv_actor add_recv_actor_stage2_fn = self.add_recv_actor_stage2 else: assert len(add_recv_actor_fn) == 2, ( "The length of add_recv_actor_fn should be 2. The first one is a function handler for communication stage 1, " "while the second one is a function handler for communication stage 2." 
) add_recv_actor_stage1_fn = add_recv_actor_fn[0] add_recv_actor_stage2_fn = add_recv_actor_fn[1] src_ranks, dst_ranks = self.get_src_and_dst_dp_ranks(is_except_routed_experts=True) if self._debug and (src_ranks[0] is None or dst_ranks is None): return replica_rank_iter = cycle(iter(src_ranks)) logger.debug(f"src_ranks: {src_ranks}") logger.debug(f"dst_ranks: {dst_ranks}") assert self.num_dst_tensor_parallel % self.num_src_tensor_parallel == 0, \ "currently we require mod value equals to zero for tensor_model_parallel_size of dst_model and that of src_model while " + \ f"src model {self.src_model.name}(TP={self.num_src_tensor_parallel}) and " + \ f"dst model {self.dst_model.name}(TP={self.num_dst_tensor_parallel})" assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0 def split_ranks_by_tp_and_ep_size(ranks, tp_size, ep_size): if ep_size > 1: sort_ranks_on_grouped_tp = [] index = 0 tp_index = 0 for _ in range(len(ranks)): sort_ranks_on_grouped_tp.append(index) if tp_index < tp_size - 1: index += 1 tp_index += 1 else: start_index = index + 1 - tp_size index = start_index + (ep_size * tp_size) tp_index = 0 if index >= len(ranks): index = (index + tp_size) % len(ranks) else: sort_ranks_on_grouped_tp = ranks return [sort_ranks_on_grouped_tp[i:i + tp_size] for i in range(0, len(sort_ranks_on_grouped_tp), tp_size)] pair_list = [] p2p_list = [] src_replica_offset = 0 lb_dst_offset_pq_dict = {} for dst_replica_ranks in dst_ranks: src_replica_ranks = next(replica_rank_iter) # for weight except routed expert, ep_size using for data parallel. 
src_replica_ranks_group = split_ranks_by_tp_and_ep_size(src_replica_ranks, self.num_src_tensor_parallel, 1) dst_replica_ranks_group = split_ranks_by_tp_and_ep_size(dst_replica_ranks, self.num_dst_tensor_parallel, self.num_dst_expert_parallel) logger.debug(f"src_replica_ranks_group: {src_replica_ranks_group}") logger.debug(f"dst_replica_ranks_group: {dst_replica_ranks_group}") pipe_map_interval = self.num_src_pipeline_stage // self.num_dst_pipeline_stage assert pipe_map_interval >= 1, \ f"dst_pp expected to divide src_pp, while src_pp {self.num_src_pipeline_stage} and dst_pp {self.num_dst_pipeline_stage}" # stage 1: comm pairs that broadcast params from trainer to inference model # Each rank in trainer holds weights for tp_num_mapping ranks in inference model. # For example: trainer_tp = 2, inference_tp = 4 => tp_num_mapping = inference_tp // trainer_tp = 2 # Weight mapping from training to inference: # [0] -> [0', 1'] # [1] -> [2', 3'] # To avoid p2p communication on the same gpu, we only broadcast params to first rank in weight_mapping_group. 
# Comm mapping from training to inference: # [0] -> [0'] # [1] -> [2'] # Firstly, pre-allocate for those gpu collisions uncollided_index_to_start_j = {} for i, src_tp_group in enumerate(src_replica_ranks_group): if i < src_replica_offset: continue j = (i - src_replica_offset) // pipe_map_interval if j == self.num_dst_pipeline_stage: src_replica_offset = i break if self.tp_num_mapping == 1: start = 0 else: mod_i = (i - src_replica_offset) % self.tp_num_mapping start = mod_i if (i - src_replica_offset) < self.tp_num_mapping else (self.tp_num_mapping - mod_i - 1) % self.tp_num_mapping for s_idx, src_rank in enumerate(src_tp_group): dst_rank, is_collide = self.get_load_balance_dst_rank( lb_dst_offset_pq_dict, s_idx, start, src_rank, dst_replica_ranks_group, j, pre_allocate=True ) if is_collide: add_recv_actor_stage1_fn(src_rank, dst_rank) pair_list.append((src_rank, dst_rank)) else: assert dst_rank is None uncollided_index_to_start_j.update({(i, s_idx) : (start, j)}) # Then, allocate src_ranks without gpu collisions for i, src_tp_group in enumerate(src_replica_ranks_group): for s_idx, src_rank in enumerate(src_tp_group): if (i, s_idx) not in uncollided_index_to_start_j: continue start, j = uncollided_index_to_start_j.get((i, s_idx)) dst_rank, _ = self.get_load_balance_dst_rank( lb_dst_offset_pq_dict, s_idx, start, src_rank, dst_replica_ranks_group, j, pre_allocate=False ) add_recv_actor_stage1_fn(src_rank, dst_rank) pair_list.append((src_rank, dst_rank)) # stage 2: comm pairs that broadcast params from first rank to the other ranks for each weight_mapping_group # Comm mapping in each weight_mapping_group of inference: # [0'] -> [1'] # [2'] -> [3'] recv_ranks = [pair[1] for pair in pair_list] def p2p_pair_grouping(tuples): for s_idx, src_rank in enumerate(tuples): for d_idx, dst_rank in enumerate(tuples): if s_idx == d_idx or src_rank not in recv_ranks: # pylint: disable=cell-var-from-loop continue add_recv_actor_stage2_fn(src_rank, dst_rank) p2p_list.append((src_rank, 
dst_rank)) for dst_tp_group in dst_replica_ranks_group: dst_tp_group = split_ranks_by_tp_and_ep_size(dst_tp_group, self.tp_num_mapping, 1) for tuples in dst_tp_group: p2p_pair_grouping(tuples) logger.info(f"comm pair_list <train_rank, inference_rank>: {pair_list}") logger.info(f"comm p2p_list <inference_rank, inference_rank>: {p2p_list}") def _clear_sync_send_recv_parameters(self, rank_mappings:List): if len(rank_mappings) == 0: return refs = [] flagged_actors = set() for rank_mapping in rank_mappings: if len(rank_mapping) == 0: continue for send_actor, recv_actors in rank_mapping.items(): if send_actor not in flagged_actors: refs.append(send_actor.clear_sync_send_recv_parameters.remote()) flagged_actors.add(send_actor) for recv_actor in recv_actors: if recv_actor not in flagged_actors: refs.append(recv_actor.clear_sync_send_recv_parameters.remote()) flagged_actors.add(recv_actor) future.get(refs) def _clear_send_recv_param_names(self): self._send_recv_param_names = {} def _clear_sorted_send_actors(self, sorted_send_actors_list:List): if len(sorted_send_actors_list) == 0: return for sorted_send_actors in sorted_send_actors_list: if sorted_send_actors is not None: sorted_send_actors = None def clear_cache(self, sorted_send_actors_list=None, rank_mapping_list=None): if sorted_send_actors_list is None: sorted_send_actors_list = [ self.send_actors_to_regroup_routed_experts, self.sorted_send_actors, self.sorted_send_actors_stage2 ] if rank_mapping_list is None: rank_mapping_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2] self._clear_sync_send_recv_parameters(rank_mapping_list) self._clear_send_recv_param_names() self._clear_sorted_send_actors(sorted_send_actors_list) def validate_sync_results(self, send_actor, recv_actors, requires_grad, filter_fn=None, param_group="default"): assert param_group in ("default", "routed", "except_routed"), ( f"param_group must be one of 'default', 'routed', or 'except_routed', got {param_group}." 
) def validate(): src_names, dst_names = self.set_sync_param_names(send_actor, recv_actors[0], requires_grad, filter_fn, param_group) # check the value of src model and tgt model pipe_stage = self.get_actor_pipe_rank(send_actor) res = [send_actor.reset_sync_parameters.remote(src_names, pipe_stage)] for recv_actor in recv_actors: res.append(recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage)) future.wait(res) src_names, dst_names = future.get([send_actor.get_parameter_to_sync_names.remote(pipe_stage), recv_actors[0].get_parameter_to_sync_names.remote(pipe_stage)]) assert len(src_names) == len(dst_names), ( f"expect the length of src_names and dst_names being the same, got {len(src_names)} and {len(dst_names)}" ) # check the value of src model and tgt model names = list(zip(src_names, dst_names)) for src_name, dst_name in tqdm(names): if param_group in ("default", "except_routed"): src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True, self.tp_num_mapping > 1)) elif param_group == "routed": src_tensor = future.get(send_actor.get_parameter_to_sync.remote(src_name, pipe_stage, True)) if src_tensor.isnan().any(): raise RuntimeError(f"weight {src_name} from send actor is nan, please check checkpoint or training process.") src_tensor_shape = src_tensor.shape for recv_actor in recv_actors: dst_tensor = future.get(recv_actor.get_parameter_to_sync.remote(dst_name, pipe_stage, True)) if dst_tensor.isnan().any(): raise RuntimeError(f"weight {dst_name} in recv actor is nan, please check param sync.") if param_group in ("default", "except_routed"): if self.tp_num_mapping == 1: # for trainer_tp == inference_tp assert src_tensor.shape == dst_tensor.shape, ( f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match." ) assert torch.allclose(src_tensor, dst_tensor, atol=1e-06), ( f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match." 
) else: # for inference_tp % trainer_tp == 0 and inference_tp > trainer_tp dst_tensor_shape = dst_tensor.shape src_tensor = src_tensor.reshape(-1) dst_tensor = dst_tensor.reshape(-1) tp_slice = self.actor2rank[recv_actor] % self.tp_num_mapping if src_tensor.shape == dst_tensor.shape: src_tensor_slice = src_tensor else: assert ( src_tensor.shape[0] % dst_tensor.shape[0] == 0 and src_tensor.shape[0] // dst_tensor.shape[0] == self.tp_num_mapping ), ( f"num of elements in src_tensor must be divided by that of dst_tensor. " f"while src {src_name}: {src_tensor_shape} and dst {dst_name}: {dst_tensor_shape}." ) start = dst_tensor.shape[0] * tp_slice end = start + dst_tensor.shape[0] src_tensor_slice = src_tensor[start:end] assert torch.allclose(src_tensor_slice, dst_tensor, atol=1e-06), ( f"after weight sync {src_name}_{tp_slice}: " f"{src_tensor_slice.view(dst_tensor_shape)} and {dst_name}: {dst_tensor.view(dst_tensor_shape)} do not match." ) elif param_group == "routed": assert self.hep_num_mapping == 1 assert src_tensor.shape == dst_tensor.shape, ( f"after weight sync {src_name}: {src_tensor.shape} and {dst_name}: {dst_tensor.shape} do not match." ) assert torch.allclose(src_tensor, dst_tensor, atol=1e-06), ( f"after weight sync {src_name}: {src_tensor} and {dst_name}: {dst_tensor} do not match." 
) return True logger.info("Going to validate transmitted tensors...") validate() logger.info("Validation passed!") def set_sync_param_names_stage2(self, send_actor, recv_actor, to_rank, requires_grad, filter_fn=None, param_group="default"): send_names, _ = self.set_sync_param_names(send_actor, send_actor, requires_grad, filter_fn, param_group) refs = [] refs.append(send_actor.set_send_parameters.remote(send_names, self.get_actor_pipe_rank(send_actor))) refs.append(recv_actor.set_recv_parameters.remote(to_rank, send_names, self.get_actor_pipe_rank(recv_actor))) future.get(refs) return send_names, send_names def sync_broadcast_two_stage(self, actors, group_name, requires_grad=None, stage2=False, filter_fn=None, param_group="default"): send_actor = actors[0] start_time = time.time() stage_str = "STAGE1" if stage2 is False else "STAGE2" for rank, recv_actor in enumerate(actors[1:]): if stage2: self.set_sync_param_names_stage2(send_actor, recv_actor, self.actor2rank[recv_actor], requires_grad, filter_fn, param_group) else: self.set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group) pipe_stage = self.get_actor_pipe_rank(send_actor) shape_refs = [] shape_refs.append(send_actor.get_parameter_shape.remote(pipe_stage)) shape_refs.append(recv_actor.get_parameter_shape.remote(pipe_stage)) send_shape_list, recv_shape_list = future.get(shape_refs) buffer_num = {} tp_division = {} for send_name_and_shape, recv_name_and_shape in zip(send_shape_list, recv_shape_list): send_param_num = send_name_and_shape[1].numel() recv_param_num = recv_name_and_shape[1].numel() # for group query attention, tensor might consist of tp part and dp part. 
ele_buffer_num = 1 if send_param_num == recv_param_num else self.tp_num_mapping buffer_num[recv_name_and_shape[0]] = ele_buffer_num tp_division[send_name_and_shape[0]] = ele_buffer_num refs = [] refs.append(recv_actor.set_tp_num_mapping.remote(self.tp_num_mapping)) refs.append(recv_actor.set_buffer_num.remote(buffer_num)) refs.append(send_actor.set_tp_num_mapping.remote(self.tp_num_mapping)) refs.append(send_actor.set_tp_division.remote(tp_division)) future.get(refs) refs = [] pipe_stage = self.get_actor_pipe_rank(send_actor) send_rank = 0 if stage2: assert len(actors) == 2, f"expect only 2 actors for stage2. \ sync params of relative rank to other slices of inference model. while {len(actors)}" for rank, actor in enumerate(actors): sync_buffer_rank = self.actor2rank[actors[1]] if rank == 0 and stage2 else 0 ref = actor.broadcast_parameter_two_stage.remote( self.actor2rank[actor], sync_buffer_rank, rank, send_rank, group_name, pipe_stage, stage2) refs.append(ref) rets = future.wait(refs, return_output=True) logger.info(f"sync_broadcast_two_stage done {stage_str} {group_name} using {time.time()-start_time} seconds") return rets def sync_broadcast(self, actors, group_name, requires_grad=None, filter_fn=None, param_group="default"): send_actor = actors[0] for recv_actor in actors[1:]: self.set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group) pipe_stage = self.get_actor_pipe_rank(send_actor) refs = [] for rank, actor in enumerate(actors): ref = actor.broadcast_parameter.remote(rank, 0, group_name, pipe_stage) refs.append(ref) future.wait(refs, return_output=True) def sync_allgather(self, actors, group_name, requires_grad=None, filter_fn=None): # Currently, only routed experts are to be all-gathered. 
for actor in actors: self.set_sync_param_names(actor, actor, requires_grad, filter_fn, param_group="routed", should_map_name=False) pipe_stage = self.get_actor_pipe_rank(actors[0]) refs = [] for actor in actors: ref = actor.allgather_routed_expert_parameter.remote(group_name, pipe_stage) refs.append(ref) future.wait(refs, return_output=True) def sync_alltoall(self, actors, requires_grad=None, filter_fn=None): # Currently, only routed experts are to be synced with all-to-all. for actor in actors: self.set_sync_param_names(actor, actor, requires_grad, filter_fn, param_group="routed", should_map_name=False) pipe_stage = self.get_actor_pipe_rank(actors[0]) refs = [] logger.info(f"apply alltoall among {[self.actor2rank[actor] for actor in actors]}") for actor in actors: ref = actor.alltoall_routed_expert_parameter.remote(pipe_stage) refs.append(ref) future.wait(refs, return_output=True) def _sync_send_recv(self, send_actor, recv_actor, requires_grad=None, filter_fn=None, param_group="default"): self.set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group) pipe_stage = self.get_actor_pipe_rank(send_actor) is_the_same_gpu = self.is_same_gpu(send_actor, recv_actor) if is_the_same_gpu: name2ref = send_actor.ray_put_parameter.remote(self.group_name, pipe_stage) recv_ref = recv_actor.ray_get_parameter.remote(self.group_name, name2ref, pipe_stage) future.get(recv_ref) else: send_ref = send_actor.send_parameter.remote(self.actor2rank[recv_actor], self.group_name, pipe_stage) recv_ref = recv_actor.recv_parameter.remote(self.actor2rank[send_actor], self.group_name, pipe_stage) future.get([send_ref, recv_ref]) logger.debug(f"sync all parameters from {send_actor} to {recv_actor}") def sync_send_recv(self, send_actor, recv_actor, requires_grad=None, filter_fn=None, param_group="default"): try: self._sync_send_recv(send_actor, recv_actor, requires_grad, filter_fn, param_group) except Exception: 
future.get(self.error_signal.set.remote(traceback.format_exc()))

    def check_param_names(self, send_actor, recv_actor, src_names, dst_names):
        """Verify that both actors know the parameter names about to be synced.

        Raises:
            RuntimeError: if ``check_param_exists`` reports a falsy state for the
                source names on ``send_actor`` or the destination names on
                ``recv_actor``.
        """
        ref0 = send_actor.check_param_exists.remote(src_names)
        ref1 = recv_actor.check_param_exists.remote(dst_names)
        states = future.get([ref0, ref1])
        if not states[0]:
            raise RuntimeError(f"Check src parameters to sync fail {src_names}")
        if not states[1]:
            raise RuntimeError(f"Check dst parameters to sync fail {dst_names}")

    def get_actor_pipe_rank(self, actor):
        """Return the pipeline-parallel rank of ``actor`` (looked up through ``self._actor2pipe``)."""
        def inner_func():
            return future.get(actor.pipeline_parallel_rank.remote())
        # presumably get_or_cache only calls inner_func when the actor is not cached yet -- confirm
        return utils.get_or_cache(self._actor2pipe, actor, inner_func)

    def get_actor_tp_rank(self, actor):
        """Return the tensor-parallel rank of ``actor`` (looked up through ``self._actor2tp``)."""
        def inner_func():
            return future.get(actor.tensor_parallel_rank.remote())
        return utils.get_or_cache(self._actor2tp, actor, inner_func)

    def get_actor_ep_rank(self, actor):
        """Return the expert-parallel rank of ``actor`` (looked up through ``self._actor2ep``)."""
        def inner_func():
            return future.get(actor.expert_parallel_rank.remote())
        return utils.get_or_cache(self._actor2ep, actor, inner_func)

    def get_actor_dp_rank(self, actor):
        """Return the data-parallel rank of ``actor`` (looked up through ``self._actor2dp``)."""
        def inner_func():
            return future.get(actor.get_data_parallel_rank.remote())
        return utils.get_or_cache(self._actor2dp, actor, inner_func)

    def _set_sync_param_names(self, send_actor, recv_actor, requires_grad=None, filter_fn=None, param_group="default", should_map_name=True):
        """Resolve and register the parameter names synced from ``send_actor`` to ``recv_actor``.

        Args:
            send_actor: source actor; all name queries are issued against it.
            recv_actor: destination actor.
            requires_grad: ``None`` is treated as ``True``; forced to ``False``
                when LoRA is enabled.
            filter_fn: optional callable applied to the source and destination
                name lists independently.
            param_group: one of ``"default"``, ``"routed"`` or ``"except_routed"``.
            should_map_name: when false the synchronizer mapping is still invoked
                but its returned names are discarded.

        Returns:
            ``(src_names, dst_names)``.

        Raises:
            RuntimeError: propagated from ``check_param_names``.
        """
        if requires_grad is None:
            requires_grad = True
        if self._enable_lora:
            # TODO(jiangle.jl): support freeze layer.
            requires_grad = False
        assert param_group in ("default", "routed", "except_routed"), (
            f"param_group must be one of 'default', 'routed', or 'except_routed', got {param_group}."
        )

        if self.num_src_pipeline_stage > 1:
            # With several source pipeline stages the sender builds a
            # destination-name -> source-name mapping for this receiver.
            dst_pipe_rank = self.get_actor_pipe_rank(recv_actor)
            dst_layer_offset = self.get_or_cache(recv_actor, "get_pipeline_stage_layer_offset")
            dst_src_mappings = future.get(send_actor.build_pipeline_layer_name_mapping.remote(
                self.num_dst_pipeline_stage, dst_pipe_rank, dst_layer_offset,
                requires_grad=requires_grad))
            dst_names = list(dst_src_mappings.keys())
            src_names = list(dst_src_mappings.values())
        else:
            # Both names refer to the very same list object here.
            src_names = dst_names = future.get(send_actor.get_parameter_names.remote(requires_grad=requires_grad))

        if self._enable_lora:
            # Drop every name containing the LoRA weight prefix.
            src_names = [ele for ele in src_names if LORA_WEIGHT_PREFIX not in ele]
            dst_names = [ele for ele in dst_names if LORA_WEIGHT_PREFIX not in ele]

        if filter_fn is not None:
            src_names = filter_fn(src_names)
            dst_names = filter_fn(dst_names)

        synchronizer = get_synchronizer(self.src_model, self.dst_model)
        if should_map_name:
            src_names, dst_names = synchronizer.map_name_from_src_to_dst(send_actor, recv_actor, src_names, dst_names)
        else:
            # For routed experts which need to regroup expert first in trainer actors.
            synchronizer.map_name_from_src_to_dst(send_actor, recv_actor, src_names, dst_names)
        # Remember the synchronizer per sender; recover_synchronizer() re-installs it later.
        self.actor2synchronizer[send_actor] = synchronizer
        future.wait(send_actor.set_synchronizer.remote(synchronizer))

        self.check_param_names(send_actor, recv_actor, src_names, dst_names)
        dst_model = self.actor2model[recv_actor]
        # The parenthesised condition is equivalent to
        # "dst uses vLLM, or the group is not 'routed'".
        if self.tp_num_mapping > 1 and ((not dst_model.use_vllm_backend and param_group != "routed") or dst_model.use_vllm_backend):
            # Accumulate destination names under a receiver-to-receiver key in
            # the same cache that set_sync_param_names reads.
            key = (recv_actor, recv_actor, param_group)
            if key not in self._send_recv_param_names:
                self._send_recv_param_names[key] = (dst_names, dst_names)
            else:
                dst_names0 = self._send_recv_param_names[key][0]
                # "+=" extends the cached list in place.
                dst_names0 += dst_names
                self._send_recv_param_names[key] = (dst_names0, dst_names0)
        if not self.synchronizer.is_parameter_changed:
            pipe_stage = self.get_actor_pipe_rank(send_actor)
            refs = []
            refs.append(send_actor.set_sync_parameters.remote(src_names, pipe_stage))
            refs.append(recv_actor.set_sync_parameters.remote(dst_names, pipe_stage))
            future.get(refs)
        return src_names, dst_names

    def set_sync_param_names(self, send_actor, recv_actor, requires_grad=None, filter_fn=None, param_group="default", should_map_name=True):
        """Return the (src, dst) parameter names for an actor pair and reset both actors to them.

        The names are looked up in ``self._send_recv_param_names`` under the key
        ``(send_actor, recv_actor, param_group)``, with ``_set_sync_param_names``
        as the producer handed to ``utils.get_or_cache``. ``reset_sync_parameters``
        is then issued on both actors on every call.

        Returns:
            ``(src_names, dst_names)``.
        """
        src_names, dst_names = utils.get_or_cache(self._send_recv_param_names, (send_actor, recv_actor, param_group), \
            lambda: self._set_sync_param_names(send_actor, recv_actor, requires_grad, filter_fn, param_group, should_map_name))
        logger.debug(f"{self.actor2rank[send_actor]} -> {self.actor2rank[recv_actor]}: {src_names[:5]} -> {dst_names[:5]}")
        pipe_stage = self.get_actor_pipe_rank(send_actor)

        refs = []
        refs.append(send_actor.reset_sync_parameters.remote(src_names, pipe_stage))
        refs.append(recv_actor.reset_sync_parameters.remote(dst_names, pipe_stage))
        future.get(refs)
        return src_names, dst_names

    def create_broadcast_group(self, send_actor, recv_actors, group_name=None, param_group="default"):
        actor_groups = [send_actor]
        actor_groups.extend(recv_actors)
        # Use self.actor2rank to ensure a globally unique number within a param_group.
send_actor_rank = self.actor2rank[send_actor] recv_actor_ranks = '_'.join([str(self.actor2rank[actor]) for actor in recv_actors]) # Always include self.group_name to ensure the name of a param_group is unique. if group_name is None: group_name = self.group_name elif not group_name.startswith(self.group_name + "_"): group_name = self.group_name + "_" + group_name finalized_group_name = f"{group_name}_{param_group}_from_{send_actor_rank}_to_{recv_actor_ranks}" logger.debug(f"finalized_group_name is {finalized_group_name}") logger.debug(f"current collevtive_groups is {self.collective_groups}") logger.debug(f"send_actor: {send_actor}, recv_actors: {recv_actors}") if finalized_group_name not in self.collective_groups: refs = [] for rank, actor in enumerate(actor_groups): ref = actor.setup_collective_group.remote(rank, len(actor_groups), "nccl", finalized_group_name) refs.append(ref) future.wait(refs) self.collective_groups.append(finalized_group_name) self.groups2actors[finalized_group_name] = tuple(actor_groups) return actor_groups, finalized_group_name def create_allgather_group(self, actor_groups, group_name=None): # Use self.actor2rank to ensure a globally unique number within a param_group. actor_ranks = '_'.join([str(self.actor2rank[actor]) for actor in actor_groups]) # Always include self.group_name to ensure the name of a param_group is unique. 
if group_name is None: group_name = self.group_name elif not group_name.startswith(self.group_name + "_"): group_name = self.group_name + "_" + group_name finalized_group_name = f"{group_name}_routed_among_{actor_ranks}" logger.debug(f"finalized_group_name is {finalized_group_name}") logger.debug(f"current collevtive_groups is {self.collective_groups}") if finalized_group_name not in self.collective_groups: refs = [] for rank, actor in enumerate(actor_groups): ref = actor.setup_collective_group.remote(rank, len(actor_groups), "nccl", finalized_group_name) refs.append(ref) future.wait(refs) self.collective_groups.append(finalized_group_name) return actor_groups, finalized_group_name def sort_send_actors(self, send_recv_actor_mappings, sorted_send_actors): if sorted_send_actors is not None: return sorted_send_actors dp2send_actors = defaultdict(list) for send_actor in send_recv_actor_mappings: dp2send_actors[self.get_actor_dp_rank(send_actor)].append(send_actor) for dp_rank in dp2send_actors: send_actors = dp2send_actors[dp_rank] dp2send_actors[dp_rank] = sorted(send_actors, key=lambda x: self.actor2rank[x]) sorted_send_actors = [] dp_rank = 0 while len(sorted_send_actors) < len(send_recv_actor_mappings): sorted_send_actors.append(dp2send_actors[dp_rank].pop(0)) dp_rank += 1 # dp_rank not in dp2send_actors happens when inference replica number less than training replica number if dp_rank == self.src_dp_size or dp_rank not in dp2send_actors: dp_rank = 0 assert len(send_recv_actor_mappings) == len(sorted_send_actors) return sorted_send_actors def sync_broadcast_second_stage_internal(self, group_name, thread_group, requires_grad=None, filter_fn=None, param_group="default", dryrun=False): max_workers = len(thread_group) logger.info(f"Use {max_workers} workers for second_stage_internal broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for idx, actor_group in enumerate(thread_group): send_actor, recv_actor = actor_group 
group_name_with_idx = f"{group_name}_{idx}" actor_groups, finalized_group_name = self.create_broadcast_group( send_actor, [recv_actor], group_name=group_name_with_idx, param_group=param_group ) if dryrun: continue futures.append(executor.submit( self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, True, filter_fn, param_group)) for _future in concurrent.futures.as_completed(futures): try: _future.result() except Exception as e: traceback.print_exc() raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) def sync_broadcast_second_stage(self, group_name, thread_groups, requires_grad=None, filter_fn=None, param_group="default", dryrun=False): tp_size = self.num_dst_tensor_parallel num_thread_groups = len(thread_groups) // tp_size new_thread_groups = [thread_groups[tp_size*i:tp_size*(i+1)] for i in range(num_thread_groups)] if not new_thread_groups: new_thread_groups = [thread_groups] max_workers = 1 logger.info(f"Use {max_workers} workers for second_stage broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for idx, thread_group in enumerate(new_thread_groups): group_name_with_idx = f"{group_name}_{idx}" futures.append(executor.submit( self.sync_broadcast_second_stage_internal, group_name_with_idx, thread_group, requires_grad, filter_fn, param_group, dryrun)) for _future in concurrent.futures.as_completed(futures): try: _future.result() except Exception as e: traceback.print_exc() raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) def sync_broadcast_multi_threads( self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None, group_name=None, stage2=False, filter_fn=None, param_group="default", dryrun=False): if stage2: thread_group = [] for send_actor in sorted_send_actors: recv_actors = 
send_recv_actor_mappings[send_actor] for recv_actor in recv_actors: thread_group.append((send_actor, recv_actor)) actor_groups_to_sync = [] for group in thread_group: new_actor_group_flag = True for idx, actor_groups in enumerate(actor_groups_to_sync): in_actor_group = False for actor_group in actor_groups: if group[0] in actor_group or group[1] in actor_group: in_actor_group = True if not in_actor_group: new_actor_group_flag = False actor_groups_to_sync[idx].append(group) #pylint: disable=unnecessary-list-index-lookup break if new_actor_group_flag or not actor_groups_to_sync: actor_groups_to_sync.append([group]) for group_idx, actor_groups in enumerate(actor_groups_to_sync): if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: self.sync_broadcast_second_stage( f"{group_name}_{group_idx}", actor_groups, requires_grad, filter_fn, param_group, dryrun=dryrun ) else: raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") else: max_workers = len(sorted_send_actors) logger.info(f"Use {max_workers} workers for first_stage broadcasting.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for send_actor in sorted_send_actors: recv_actors = send_recv_actor_mappings[send_actor] if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: actor_groups, finalized_group_name = self.create_broadcast_group( send_actor, recv_actors, group_name=group_name, param_group=param_group ) if not dryrun: futures.append(executor.submit( self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group )) else: raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.") for _future in concurrent.futures.as_completed(futures): try: _future.result() except Exception as e: traceback.print_exc() raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) def 
sync_allgather_multi_threads( self, send_actors, max_workers=1, requires_grad=None, group_name=None, filter_fn=None ): send_actors_to_allgather_routed_experts = send_actors[0] logger.info(f"Use {max_workers} workers for allgather multiprocessing.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for allgather_actors in send_actors_to_allgather_routed_experts: actor_groups, finalized_group_name = self.create_allgather_group(allgather_actors, group_name=group_name) futures.append(executor.submit( self.sync_allgather, actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn )) for _future in concurrent.futures.as_completed(futures): try: _future.result() except Exception as e: raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) def sync_alltoall_multi_threads( self, send_actors, max_workers=1, requires_grad=None, filter_fn=None ): send_actors_to_alltoall_routed_experts = send_actors[0] max_workers = len(send_actors_to_alltoall_routed_experts) logger.info(f"Use {max_workers} workers for alltoall multiprocessing.") with ThreadPoolExecutor(max_workers=max_workers) as executor: futures = [] for actor_groups in send_actors_to_alltoall_routed_experts: futures.append(executor.submit( self.sync_alltoall, actor_groups, requires_grad, filter_fn=filter_fn )) for _future in concurrent.futures.as_completed(futures): try: _future.result() except Exception as e: raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) def check_and_setup_collective_group(self): if not self._is_collective_group_created: # Re-create collective group only when it is destroyed before. 
assert self._free_sync_collective_group self.setup_collective_group() def check_and_destroy_collective_group(self): if self._free_sync_collective_group: self.destroy_collective_group() self._is_collective_group_created = False self.collective_groups = [] self.groups2actors = {} def check_and_fuse_lora(self, enable_lora, actor_mapping): send_actors_set = set() def check_and_fuse_lora_internal(actor_mapping_item): for send_actor in actor_mapping_item: if enable_lora and send_actor not in send_actors_set: ref = send_actor.fuse_lora_layer.remote() state = future.get([ref]) assert state, "Check fuse lora layer fail." send_actors_set.add(send_actor) if isinstance(actor_mapping, List): for actor_mapping_item in actor_mapping: if actor_mapping_item is None: continue check_and_fuse_lora_internal(actor_mapping_item) elif isinstance(actor_mapping, Dict): if actor_mapping is None: return check_and_fuse_lora_internal(actor_mapping) else: raise ValueError("unrecognized type for actor_mapping, expect: List or Dict") def check_and_unfuse_lora(self, enable_lora, actor_mapping): send_actors_set = set() def check_and_unfuse_lora_internal(actor_mapping_item): for send_actor in actor_mapping_item: if self._enable_lora and send_actor not in send_actors_set: ref = send_actor.unfuse_lora_layer.remote() state = future.get([ref]) assert state, "Check unfuse lora layer fail." 
send_actors_set.add(send_actor) if isinstance(actor_mapping, List): for actor_mapping_item in actor_mapping: if actor_mapping_item is None: continue check_and_unfuse_lora_internal(actor_mapping_item) elif isinstance(actor_mapping, Dict): if actor_mapping is None: return check_and_unfuse_lora_internal(actor_mapping) else: raise ValueError("unrecognized type for actor_mapping, expect: List or Dict") def validate_sync_results_parallel(self, actor_mappings_list:List, requires_grad=None, validate=False, filter_fn=None, param_group="default"): if self._debug or validate: assert len(actor_mappings_list) in (1, 2), f"The length of actor mapping list should be 1 or 2, but got {len(actor_mappings_list)}." args = [] for send_actor, recv_actors in actor_mappings_list[0].items(): for recv_actor in recv_actors: if len(actor_mappings_list) == 1: args.append((send_actor, [recv_actor], requires_grad, filter_fn, param_group)) elif len(actor_mappings_list) == 2: recv_actors_stage2 = actor_mappings_list[1].get(recv_actor, []) args.append((send_actor, [recv_actor] + recv_actors_stage2, requires_grad, filter_fn, param_group)) if self._debug: for arg in args: self.validate_sync_results(arg[0], arg[1], arg[2], arg[3], arg[4]) else: execute_in_parallel(self.validate_sync_results, args) def _calculate_max_workers(self, sorted_send_actors, actor_mappings=None): max_workers = get_args().runtime_args.param_sync_max_workers if max_workers is None: max_workers = max(self.src_model.total_gpu // self.num_src_pipeline_stage, 1) if max_workers == -1: if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: max_workers = len(sorted_send_actors) else: assert actor_mappings is not None, ( "actor_mappings should not be None when max_workers is -1 and " "communication type for parameter synchronization is not broadcast." 
) max_workers = len(sorted_send_actors) * len(actor_mappings[sorted_send_actors[0]]) return max_workers def _multi_thread_sync_for_tp_num_mapping_gt_1( self, send_actors:List, actor_mappings:List, requires_grad=None, filter_fn=None, param_group="default", dryrun=False ): assert len(send_actors) == 2, ( f"Expect the length of send_actors being 2 for TP num mapping greater than 1, but got {len(send_actors)}." ) send_actors_stage1 = send_actors[0] # pylint: disable=unused-variable send_actors_stage2 = send_actors[1] # pylint: disable=unused-variable assert len(actor_mappings) == 2, ( f"Expect the length of actor_mappings being 2 for TP num mapping greater than 1, but got {len(actor_mappings)}." ) actor_mappings_stage1 = actor_mappings[0] actor_mappings_stage2 = actor_mappings[1] # stage 1 self.timers("stage1").start() sorted_send_actors_stage1 = list(actor_mappings_stage1.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage1, actor_mappings_stage1) group_name = self.group_name + "_stage1_comm" self.sync_broadcast_multi_threads( sorted_send_actors_stage1, actor_mappings_stage1, max_workers, requires_grad, group_name=group_name, stage2=False, filter_fn=filter_fn, param_group=param_group, dryrun=dryrun ) self.timers("stage1").stop() logger.info(f"finish stage1| {self.timers.log(names=['stage1'])}") # stage 2 self.timers("stage2").start() sorted_send_actors_stage2 = list(actor_mappings_stage2.keys()) max_workers = self._calculate_max_workers(sorted_send_actors_stage2, actor_mappings_stage2) group_name = self.group_name + "_stage2_comm" self.sync_broadcast_multi_threads( sorted_send_actors_stage2, actor_mappings_stage2, max_workers, requires_grad, group_name=group_name, stage2=True, filter_fn=filter_fn, param_group=param_group, dryrun=dryrun) self.timers("stage2").stop() logger.info(f"finish stage2| {self.timers.log(names=['stage2'])}") def split_sync_groups(self, send_actors, actor_mappings): groups = [] for send_actor in send_actors: recv_actors = 
actor_mappings[send_actor] rank_dict = [self.actor2rank[actor] for actor in [send_actor] + recv_actors] # gen groups placed = False for group in groups: if set(group["values"]).isdisjoint(rank_dict): group["keys"].append(send_actor) group["values"] = group["values"] + rank_dict placed = True break if not placed: groups.append({ "keys": [send_actor], "values": rank_dict.copy() }) total_elements = sum(len(group["keys"]) for group in groups) assert total_elements == len(send_actors), \ (f"needed total elements of groups {total_elements} == len of send_actors \ {len(send_actors)} in param sync.") for group in groups: assert len(group["values"]) == len(set(group["values"])), \ (f"the elements must be all different in group: {group['values']}") logger.info(f"split_sync_groups: {groups}") return [g["keys"] for g in groups] def _multi_thread_sync_for_tp_num_mapping_eq_1( self, send_actors_list:List, actor_mappings_list:List, requires_grad=None, filter_fn=None, param_group="default", dryrun=False ): assert len(send_actors_list) == 1 and len(actor_mappings_list) == 1 send_actors = send_actors_list[0] actor_mappings = actor_mappings_list[0] sorted_send_actors = self.sort_send_actors(actor_mappings, send_actors) max_workers = self._calculate_max_workers(sorted_send_actors, actor_mappings) src_pp_size = self.num_src_pipeline_stage groups = self.split_sync_groups(sorted_send_actors, actor_mappings) logger.info(f"Use {max_workers} workers for tp_num_mapping_eq_1 synchoronization, \ src_pp_size: {src_pp_size}, groups: {len(groups)}.") with ThreadPoolExecutor(max_workers=max_workers) as executor: for group in groups: t1 = time.time() futures = [] for send_actor in group: recv_actors = actor_mappings[send_actor] logger.info(f"Sending from {[self.actor2rank[send_actor]]} to {[self.actor2rank[actor] for actor in recv_actors]}.") if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, 
param_group=param_group) if dryrun: continue futures.append(executor.submit( self.sync_broadcast, actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group )) else: for recv_actor in recv_actors: if dryrun: continue futures.append(executor.submit( self.sync_send_recv, send_actor, recv_actor, requires_grad, filter_fn=filter_fn, param_group=param_group )) t2 = time.time() for _future in concurrent.futures.as_completed(futures): try: _future.result() except Exception as e: traceback.print_exc() raise RuntimeError(f"Parameter sync thread generated an exception: {e}") from e concurrent.futures.wait(futures) t3 = time.time() logger.info(f"sync for tp_num_mapping_eq_1, submit time(s):{(t2-t1)}, sync time(s):{(t3-t2)}") def _single_thread_sync(self, actor_mappings_list:List, requires_grad=None, filter_fn=None, param_group="default"): assert len(actor_mappings_list) == 1 actor_mappings = actor_mappings_list[0] for send_actor, recv_actors in actor_mappings.items(): if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST: actor_groups, finalized_group_name = self.create_broadcast_group(send_actor, recv_actors, param_group=param_group) self.sync_broadcast(actor_groups, finalized_group_name, requires_grad, filter_fn=filter_fn, param_group=param_group) else: for recv_actor in recv_actors: self.sync_send_recv(send_actor, recv_actor, requires_grad, filter_fn=filter_fn, param_group=param_group) def recover_synchronizer(self): refs = [] for actor, synchronizer in self.actor2synchronizer.items(): refs.append(actor.set_synchronizer.remote(synchronizer)) future.wait(refs) def reset_synchronizer(self): refs = [] for actor, _ in self.actor2synchronizer.items(): refs.append(actor.set_synchronizer.remote(None)) future.wait(refs) def sync(self, requires_grad=None, validate=False, dryrun=False): self.recover_synchronizer() self.check_and_setup_collective_group() self.check_and_fuse_lora(self._enable_lora, self.send_recv_actor_mappings) send_actors_list : 
List = [] actor_mappings_list : List = [] if self.concurrent_comm: if self.tp_num_mapping > 1: send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] actor_mappings_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2] self._multi_thread_sync_for_tp_num_mapping_gt_1( send_actors_list, actor_mappings_list, requires_grad=requires_grad ) else: send_actors_list = [self.sorted_send_actors] actor_mappings_list = [self.send_recv_actor_mappings] self._multi_thread_sync_for_tp_num_mapping_eq_1( send_actors_list, actor_mappings_list, requires_grad=requires_grad ) else: actor_mappings_list = [self.send_recv_actor_mappings] self._single_thread_sync( actor_mappings_list, requires_grad=requires_grad ) assert len(actor_mappings_list) >= 1 self.check_and_unfuse_lora(self._enable_lora, self.send_recv_actor_mappings) self.validate_sync_results_parallel(actor_mappings_list, requires_grad, validate) self.check_and_destroy_collective_group() self.reset_synchronizer() logger.info(f"Group {self.group_name} sync all parameters done, comm_type {self._comm_type}") class ParameterSyncGroupwithHEP(ParameterSyncGroup): """ParameterSyncGroup for Hyper Expert Parallel (HEP). Note that in HEP, EP size for routed experts is different from that for Megatorn-LM. For routed experts, the new EP size (we call it HEP size for clarification) = mpu.ep_size x mpu.tp_size, while Megatron-LM set the EP size as mpu.ep_size. However, the EP size of shared experts in HEP is equal to that in Megatron -LM (which is 1). In this case, routed experts treat TP and EP altogether as EP, and shared experts ignore EP just like other non-expert weights. Therefore, we manage seperate parameter sync groups for routed expert weigts and weights except routed experts in this class. 
""" def __init__(self, src_model, dst_model, group_name, frequency, error_signal): self.send_recv_actor_mappings_for_routed_experts = defaultdict(list) self.recv_send_actor_mappings_for_routed_experts = defaultdict(list) self._num_src_hyper_expert_parallel = None self._num_dst_hyper_expert_parallel = None self._actor2hep = {} self.sorted_send_actors_for_routed_experts = None super().__init__(src_model, dst_model, group_name, frequency, error_signal) def setup_rank_mapping(self): """ For now, we only allow parameter sync between two models that dst model tp size >= src model tp size && dst model ep size <= src model ep size There is no more restiction regarding PP and DP """ self.tp_num_mapping = self.num_dst_tensor_parallel // self.num_src_tensor_parallel self.ep_num_mapping = self.num_dst_expert_parallel / self.num_src_expert_parallel self.hep_num_mapping = self.num_dst_hyper_expert_parallel / self.num_src_hyper_expert_parallel assert self.tp_num_mapping >= 1, ( f"Currently, tensor parallel world size for training ({self.num_src_tensor_parallel}) should be" f"less or equal to tensor parallel world size for inference ({self.num_dst_tensor_parallel}) with HEP enabled." ) assert self.ep_num_mapping <= 1, ( f"Currently, expert parallel world size for training ({self.num_src_expert_parallel}) should be" f"greater or equal to expert parallel world size for inference ({self.num_dst_expert_parallel}) with HEP enabled." 
) if self.dst_model.use_vllm_backend: if self.tp_num_mapping == 1: if self.ep_num_mapping == 1: self.build_rank_mapping() else: self.build_rank_mapping_for_ep() elif self.tp_num_mapping > 1: if self.hep_num_mapping == 1: self.build_rank_mapping_for_ep(add_recv_actor_fn=self.add_recv_actor_for_routed_experts) # only add all-gather actors self.build_rank_mapping_for_params_except_routed_expert() else: self.build_rank_mapping_for_ep(add_recv_actor_fn=self.empty_add_recv_actor) # only add all-gather actors self.build_rank_mapping_two_stage() else: raise NotImplementedError( f"ChatLearn does not support synchronizing from larger tp size ({self.num_src_tensor_parallel})" f"to smaller tp size ({self.num_dst_tensor_parallel}) currently." ) else: if self.ep_num_mapping == 1 and self.tp_num_mapping == 1: self.build_rank_mapping() elif self.hep_num_mapping == 1: # In this case, routed experts are mapped one by one, while params except routed experts are split by TP. self.build_rank_mapping_for_routed_experts() self.build_rank_mapping_for_params_except_routed_expert() else: # We do not support other cases for HEP. Please note that tp_num_mapping > 1 with ep_num_mapping = 1 is also unsupported. raise NotImplementedError( "ChatLearn does not support inequivalent EP x TP between training and inference with Hyper Expert Parallel (HEP) enabled and " f"inference model is an instance of `MegatronModule`. Your current setting is " f"EP{self.num_src_expert_parallel} TP{self.num_src_tensor_parallel} for training model `{self.src_model.name}` " f"and EP{self.num_dst_expert_parallel} TP{self.num_dst_tensor_parallel} for inference model `{self.dst_model.name}`." 
) def build_rank_mapping_for_ep(self, add_recv_actor_fn=None): # setup rank mapping for src parameter and dst parameter # get rank for one src_model, without model replicas if add_recv_actor_fn is None: add_recv_actor_fn = self.add_recv_actor src_dp_ranks, dst_dp_ranks = self.get_src_and_dst_dp_ranks() if self._debug and (src_dp_ranks[0] is None or dst_dp_ranks is None): return assert len(src_dp_ranks[0]) % len(dst_dp_ranks[0]) == 0, \ f"src training model ranks should be times of dst ranks, but got {len(src_dp_ranks[0])} and {len(dst_dp_ranks[0])}" if self.src_model.colocate_with(self.dst_model) and self.num_src_tensor_parallel % 2 == 1: replica_rank_iter = cycle(reversed(src_dp_ranks)) else: replica_rank_iter = cycle(iter(src_dp_ranks)) logger.debug(f"src_dp_ranks: {src_dp_ranks}") logger.debug(f"dst_dp_ranks: {dst_dp_ranks}") assert self.num_src_pipeline_stage % self.num_dst_pipeline_stage == 0 def split_ranks_by_ep_and_tp_size(ranks, tp_size : int = 1, ep_size : int = 1): tp_and_ep_size = tp_size * ep_size return [[ranks[i:i + tp_size] for i in range(j, j + tp_and_ep_size, tp_size)] for j in range(0, len(ranks), tp_and_ep_size)] src_replica_ranks2offset = {} is_first_time_set_send_actors = True for dst_replica_ranks in dst_dp_ranks: src_replica_ranks = next(replica_rank_iter) if tuple(src_replica_ranks) not in src_replica_ranks2offset: src_replica_ranks2offset[tuple(src_replica_ranks)] = 0 is_first_time_set_send_actors = True else: is_first_time_set_send_actors = False src_replica_ranks_group = split_ranks_by_ep_and_tp_size(src_replica_ranks, self.num_src_tensor_parallel, self.num_src_expert_parallel) # Since dst replica is vllm and it doesn't have ep, the function will organize dst_replica_ranks_group as [pp[tp]] naturally. 
dst_replica_ranks_group = split_ranks_by_ep_and_tp_size(dst_replica_ranks, self.num_dst_tensor_parallel, self.num_dst_expert_parallel) if is_first_time_set_send_actors: self.set_send_actors_to_regroup_routed_experts(src_replica_ranks_group) self.add_routed_experts_regrouping_actor(self.src_model, src_replica_ranks_group) if add_recv_actor_fn is self.empty_add_recv_actor: continue pipe_map_interval = self.num_src_pipeline_stage // self.num_dst_pipeline_stage for i, src_ep_and_tp_group in enumerate(src_replica_ranks_group): j = i // pipe_map_interval first_src_tp_group = src_ep_and_tp_group[0] assert len(dst_replica_ranks_group[j][0]) % len(first_src_tp_group) == 0, ( "TP size of dst model should be times of src model, " f"but got {len(dst_replica_ranks_group[j][0])} and {len(first_src_tp_group)}" ) len_dst_div_src = len(dst_replica_ranks_group[j][0]) // len(first_src_tp_group) concated_src_tp_group = [] offset = src_replica_ranks2offset[tuple(src_replica_ranks)] # cycled concatenate src tp group to ensure len(concat_src_tp_group) == len(dst_replica_ranks_group[j][0]) for k in range(len_dst_div_src): concated_src_tp_group.extend(src_ep_and_tp_group[int((offset + k) % len(src_ep_and_tp_group))]) for src_rank, dst_rank in zip(concated_src_tp_group, dst_replica_ranks_group[j][0]): add_recv_actor_fn(src_rank, dst_rank) src_replica_ranks2offset[tuple(src_replica_ranks)] = int( (src_replica_ranks2offset[tuple(src_replica_ranks)] + len_dst_div_src) % len(src_ep_and_tp_group) ) if self._debug: def debug_msg_for_actor_mappings(actor_mapping): if actor_mapping is None: return for k, v_list in actor_mapping.items(): for v in v_list: logger.debug(f"actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}") debug_msg_for_actor_mappings(self.send_recv_actor_mappings) debug_msg_for_actor_mappings(self.send_recv_actor_mappings_for_routed_experts) for regroup_actors in self.send_actors_to_regroup_routed_experts: count += 1 cat_str = "_".join(str(self.actor2rank[actor]) for actor 
in regroup_actors) logger.info(f"{self._comm_type_to_regroup_routed_experts} actors: {cat_str}") for k, v_list in self.send_recv_actor_mappings.items(): for v in v_list: logger.info(f"send_recv_actor_mappings: {self.actor2rank[k]} -> {self.actor2rank[v]}") def add_recv_actor_for_routed_experts(self, src_rank, dst_rank): src_actor = self.src_model.get_actor(src_rank) self.insert_actor2rank(src_actor, src_rank) self.insert_actor2model(src_actor, self.src_model) dst_actor = self.dst_model.get_actor(dst_rank) self.insert_actor2rank(dst_actor, dst_rank) self.insert_actor2model(dst_actor, self.dst_model) src_gpu = self.get_or_cache(src_actor, "get_visible_gpus") dst_gpu = self.get_or_cache(dst_actor, "get_visible_gpus") src_tp_rank = self.get_actor_tp_rank(src_actor) dst_tp_rank = self.get_actor_tp_rank(dst_actor) src_pp_rank = self.get_actor_pipe_rank(src_actor) dst_pp_rank = self.get_actor_pipe_rank(dst_actor) src_ep_rank = self.get_actor_ep_rank(src_actor) dst_ep_rank = self.get_actor_ep_rank(dst_actor) src_hep_rank = self.get_actor_hep_rank(src_actor) dst_hep_rank = self.get_actor_hep_rank(dst_actor) logger.debug(f"build rank mapping from {src_rank} to {dst_rank}, from gpu {src_gpu} to {dst_gpu}, " + f"from pipe_stage {src_pp_rank} to {dst_pp_rank}, " + f"from tp rank {src_tp_rank} to {dst_tp_rank}, " + f"from ep rank {src_ep_rank} to {dst_ep_rank}, " + f"from hep rank {src_hep_rank} to {dst_hep_rank}.") self.send_recv_actor_mappings_for_routed_experts[src_actor].append(dst_actor) self.recv_send_actor_mappings_for_routed_experts[dst_actor].append(src_actor) @property def num_src_hyper_expert_parallel(self): if self._num_src_hyper_expert_parallel is None: self._num_src_hyper_expert_parallel = future.get(self.src_model.replicas[0].all_actors[0].tensor_and_expert_model_parallel_size.remote()) return self._num_src_hyper_expert_parallel @property def num_dst_hyper_expert_parallel(self): if self._num_dst_hyper_expert_parallel is None: self._num_dst_hyper_expert_parallel = 
future.get(self.dst_model.replicas[0].all_actors[0].tensor_and_expert_model_parallel_size.remote()) return self._num_dst_hyper_expert_parallel def get_actor_hep_rank(self, actor): def inner_func(): return future.get(actor.tensor_and_expert_model_parallel_size.remote()) return utils.get_or_cache(self._actor2hep, actor, inner_func) def build_rank_mapping_for_routed_experts(self): self.build_rank_mapping(add_recv_actor_fn=self.add_recv_actor_for_routed_experts) def build_rank_mapping_for_params_except_routed_expert(self): self.build_rank_mapping_two_stage(add_recv_actor_fn=None) def routed_experts_filter(self, name_list: List[str]): filted_names = [name for name in name_list if 'mlp.experts' in name] return filted_names def params_except_routed_expert_filter(self, name_list: List[str]): filted_names = [name for name in name_list if 'mlp.experts' not in name] return filted_names def clear_cache(self, sorted_send_actors_list=None, rank_mapping_list=None): if sorted_send_actors_list is None: sorted_send_actors_list = [ self.sorted_send_actors, self.sorted_send_actors_stage2, self.send_actors_to_regroup_routed_experts, self.sorted_send_actors_for_routed_experts ] if rank_mapping_list is None: rank_mapping_list = [ self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2, self.send_recv_actor_mappings_for_routed_experts ] self._clear_sync_send_recv_parameters(rank_mapping_list) self._clear_send_recv_param_names() self._clear_sorted_send_actors(sorted_send_actors_list) def warmup_groups(self): def warmup_tasks_func(task): actors = task.actors group = task.group refs = [] refs.append(actors[0].broadcast_dummy_tensor_send.remote(0, group)) for actor in actors[1:]: refs.append(actor.broadcast_dummy_tensor_recv.remote(0, group)) future.wait(refs) tasks = [] actors_set = set() for group_name, actors in self.groups2actors.items(): # filter actors if the same collective ring actor_ids = [self.actor2rank[actor] for actor in actors] key = tuple(sorted(actor_ids)) if key 
not in actors_set: tasks.append(CollectiveTask(actors, group_name)) parallel_execute_collective_tasks(tasks, warmup_tasks_func) def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False, dryrun=False): self.check_and_setup_collective_group() send_actors_list : List = [ self.sorted_send_actors, self.sorted_send_actors_stage2 ] actor_mappings_list : List = [ self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2, self.send_actors_to_regroup_routed_experts, ] self.check_and_fuse_lora(self._enable_lora, actor_mappings_list) if self.concurrent_comm: assert self.dst_model.use_vllm_backend max_workers = self._calculate_max_workers(self.send_actors_to_regroup_routed_experts) if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER: # allgather routed experts only self.sync_allgather_multi_threads( [self.send_actors_to_regroup_routed_experts], max_workers=max_workers, requires_grad=requires_grad, group_name=self.group_name + "_allgather", filter_fn=self.routed_experts_filter) elif self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL: if not dryrun: logger.info("start to alltoall router experts. 
") start_time = time.time() # alltoall routed experts only self.sync_alltoall_multi_threads( [self.send_actors_to_regroup_routed_experts], max_workers=max_workers, requires_grad=requires_grad, filter_fn=self.routed_experts_filter) logger.info("complete to alltoall router experts using {time.time()-start_time:.2f} seconds ") # sync everything to inference model if self.tp_num_mapping == 1: logger.info("start to sync all moe experts") send_actors_list = [self.sorted_send_actors] actor_mappings_list = [self.send_recv_actor_mappings] self._multi_thread_sync_for_tp_num_mapping_eq_1( send_actors_list, actor_mappings_list, requires_grad=requires_grad, filter_fn=None, param_group="default", dryrun=dryrun ) logger.info("complete to sync all moe experts") elif self.tp_num_mapping > 1: # First, synchronize routed experts. logger.info("start to sync routed expert weights.") start_time = time.time() self._synchronize_routed_experts(requires_grad=requires_grad, validate=validate, dryrun=dryrun) logger.info(f"complete to sync routed expert weights. 
[stage1-1] using {time.time()-start_time:.2f} seconds") self.clear_cache( sorted_send_actors_list = [ self.send_actors_to_regroup_routed_experts, self.sorted_send_actors_for_routed_experts ], rank_mapping_list=[ self.send_recv_actor_mappings_for_routed_experts ] ) # Then, synchronize parameters except routed experts logger.info("start to sync parameters except routed eperts.") self._synchronize_params_except_routed_experts(requires_grad=requires_grad, validate=validate, dryrun=dryrun) logger.info("complete to sync parameters except routed experts.") self.reset_synchronizer() self.clear_cache( sorted_send_actors_list = [ self.sorted_send_actors, self.sorted_send_actors_stage2, ], rank_mapping_list = [ self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2 ] ) else: raise NotImplementedError( f"ChatLearn does not support synchronizing from larger tp size ({self.num_src_tensor_parallel})" f"to smaller tp size ({self.num_dst_tensor_parallel}) currently." ) else: raise NotImplementedError( "ChatLearn supports only concurrent_comm for training models with HEP enabled and inference with vLLM" ) self.check_and_unfuse_lora(self._enable_lora, actor_mappings_list) self.validate_sync_results_parallel(actor_mappings_list, requires_grad, validate) self.check_and_destroy_collective_group() self.reset_synchronizer() logger.info(f"Group {self.group_name} sync all parameters done, comm_type {self._comm_type}") def _synchronize_routed_experts(self, requires_grad=None, validate=False, dryrun=False): self.check_and_setup_collective_group() self.check_and_fuse_lora(self._enable_lora, self.send_recv_actor_mappings_for_routed_experts) send_actors_list : List = [] actor_mappings_list : List = [] if self.concurrent_comm: send_actors_list = [self.sorted_send_actors_for_routed_experts] actor_mappings_list = [self.send_recv_actor_mappings_for_routed_experts] self._multi_thread_sync_for_tp_num_mapping_eq_1( send_actors_list, actor_mappings_list, requires_grad=requires_grad, 
filter_fn=self.routed_experts_filter, param_group="routed", dryrun=dryrun, ) else: actor_mappings_list = [self.send_recv_actor_mappings_for_routed_experts] self._single_thread_sync( self.send_recv_actor_mappings_for_routed_experts, requires_grad=requires_grad, filter_fn=self.routed_experts_filter, param_group="routed", ) assert len(actor_mappings_list) >= 1 self.check_and_unfuse_lora(self._enable_lora, self.send_recv_actor_mappings_for_routed_experts) self.validate_sync_results_parallel( actor_mappings_list, requires_grad, validate, filter_fn=self.routed_experts_filter, param_group="routed" ) self.check_and_destroy_collective_group() logger.info(f"Group {self.group_name} sync all parameters done, comm_type {self._comm_type}") def _synchronize_params_except_routed_experts(self, requires_grad=None, validate=False, dryrun=False): self.check_and_setup_collective_group() self.check_and_fuse_lora(self._enable_lora, self.send_recv_actor_mappings) send_actors_list : List = [] actor_mappings_list : List = [] if self.concurrent_comm: if self.tp_num_mapping > 1: send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2] actor_mappings_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2] self._multi_thread_sync_for_tp_num_mapping_gt_1( send_actors_list, actor_mappings_list, requires_grad=requires_grad, filter_fn=self.params_except_routed_expert_filter, param_group="except_routed", dryrun=dryrun ) else: send_actors_list = [self.sorted_send_actors] actor_mappings_list = [self.send_recv_actor_mappings] self._multi_thread_sync_for_tp_num_mapping_eq_1( send_actors_list, actor_mappings_list, requires_grad=requires_grad, filter_fn=self.params_except_routed_expert_filter, param_group="except_routed" ) else: actor_mappings_list = [self.send_recv_actor_mappings] self._single_thread_sync( actor_mappings_list, requires_grad=requires_grad, filter_fn=self.params_except_routed_expert_filter, param_group="except_routed" ) 
self.check_and_unfuse_lora(self._enable_lora, self.send_recv_actor_mappings) self.validate_sync_results_parallel( actor_mappings_list, requires_grad, validate, filter_fn=self.params_except_routed_expert_filter, param_group="except_routed" ) self.check_and_destroy_collective_group() logger.info(f"Group {self.group_name} sync all parameters done, comm_type {self._comm_type}") def sync(self, requires_grad=None, validate=False, dryrun=False): if self.dst_model.use_vllm_backend: self.recover_synchronizer() self._synchronize_all_moe_parameters(requires_grad=requires_grad, validate=validate, dryrun=dryrun) else: if self.ep_num_mapping == 1 and self.tp_num_mapping == 1: # synchronization is the same as base class when applying Qwen + Qwen super().sync(requires_grad, validate) return self.recover_synchronizer() # First, synchronize routed experts. self._synchronize_routed_experts(requires_grad=requires_grad, validate=validate) self.clear_cache( sorted_send_actors_list = [ self.send_actors_to_regroup_routed_experts, self.sorted_send_actors_for_routed_experts ], rank_mapping_list=[ self.send_recv_actor_mappings_for_routed_experts ] ) # Then, synchronize parameters except routed experts self._synchronize_params_except_routed_experts(requires_grad=requires_grad, validate=validate) self.reset_synchronizer() self.clear_cache( sorted_send_actors_list = [ self.sorted_send_actors, self.sorted_send_actors_stage2, ], rank_mapping_list = [ self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2 ] )