in chatlearn/synchronizer/parameter_sync.py [0:0]
def __init__(self, src_model, dst_model, group_name, frequency, error_signal):
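    """Coordinate parameter synchronization from src_model to dst_model.

    frequency: sync every n episodes; n = 0 disables parameter sync.
    """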
    self.src_model = src_model
    self.dst_model = dst_model
    self.synchronizer = get_synchronizer(src_model, dst_model)
    self.group_name = group_name
    self.error_signal = error_signal
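    # Bidirectional lookup tables between sender and receiver actors,
    # kept separately for the two synchronization stages.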
    self.send_recv_actor_mappings = defaultdict(list)
    self.recv_send_actor_mappings = defaultdict(list)
    self.send_recv_actor_mappings_stage2 = defaultdict(list)
    self.recv_send_actor_mappings_stage2 = defaultdict(list)
    self.actor2rank = {}
    self.actor2model = {}
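    # Cached parallelism metadata (pipeline / tensor / expert / data parallel
    # sizes and per-actor ranks); initialized empty here and filled in later.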
    self._debug = get_args().runtime_args.debug
    self._num_src_pipeline_stage = None
    self._num_dst_pipeline_stage = None
    self._num_src_expert_parallel = None
    self._num_dst_expert_parallel = None
    self._num_src_tensor_parallel = None
    self._num_dst_tensor_parallel = None
    self._send_recv_param_names = {}
    self._actor2pipe = {}
    self._actor2tp = {}
    self._actor2ep = {}
    self._actor2dp = {}
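    # Pick the communication backend. Colocated models cannot use BROADCAST
    # when both tensor-parallel sizes are odd, so fall back to P2P.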
    self._comm_type = get_args().runtime_args.param_sync_comm_type
    if src_model.colocate_with(dst_model) and self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
        if self.num_src_tensor_parallel % 2 == 1 and self.num_dst_tensor_parallel % 2 == 1:
            logger.warning("PARAM_SYNC_COMM_TYPE.BROADCAST is only supported when the TP size is even, falling back to P2P")
            self._comm_type = PARAM_SYNC_COMM_TYPE.P2P
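    # Runtime options: concurrent communication, LoRA on the source model,
    # sync frequency, and whether the collective group is freed between syncs.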
    self.concurrent_comm = get_args().runtime_args.concurrent_comm
    self._enable_lora = self.src_model.module_args.lora.enable_lora
    # sync every n episodes; n = 0 disables param sync
    self._frequency = frequency
    self._free_sync_collective_group = get_args().runtime_args.free_sync_collective_group
    self._is_collective_group_created = True
    self.collective_groups = []
    self.groups2actors = {}  # group_name -> list of actors
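    # Query the source model's data-parallel size from its first actor (remote Ray call).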
    self.src_dp_size = future.get(self.src_model.replicas[0].all_actors[0].get_data_parallel_size.remote())
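    # Routed-expert regrouping uses either ALLGATHER or ALLTOALL. ALLTOALL is only
    # valid when the src and dst tp * ep products match; otherwise fall back to ALLGATHER.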
    self.send_actors_to_regroup_routed_experts = None
    self._comm_type_to_regroup_routed_experts = get_args().runtime_args.routed_expert_regrouping_comm_type
    assert self._comm_type_to_regroup_routed_experts in \
        [ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER, ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL], \
        f"Only 'allgather' or 'alltoall' are supported for routed expert regrouping, got {self._comm_type_to_regroup_routed_experts}"
    if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL:
        if self.num_dst_tensor_parallel * self.num_dst_expert_parallel != self.num_src_tensor_parallel * self.num_src_expert_parallel:
            logger.info("ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL is only supported when src tp * ep equals dst tp * ep, using 'allgather' instead.")
            self._comm_type_to_regroup_routed_experts = ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER
    logger.info(f"Set ROUTED_EXPERT_REGROUPING_COMM_TYPE = {self._comm_type_to_regroup_routed_experts}.")
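    # Set up the collective communication group, compute the send/recv rank
    # mapping, and create timers.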
    self.sorted_send_actors = None
    self.sorted_send_actors_stage2 = None
    self.actor2synchronizer = {}
    self.setup_collective_group()
    self.setup_rank_mapping()
    self.timers = Timers()