in chatlearn/synchronizer/parameter_sync.py [0:0]
def _synchronize_all_moe_parameters(self, requires_grad=None, validate=False, dryrun=False):
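    """Synchronize all parameters of a HEP/MoE training model to the inference model.

    ``requires_grad``, ``validate`` and ``dryrun`` are forwarded to the underlying
    sync helpers; ``validate`` additionally triggers a parallel validation of the
    synced results at the end.
    """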
    self.check_and_setup_collective_group()
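    # Send-actor groups and send->recv mappings for the two sync stages; the extra
    # mapping entry covers the actors used to regroup routed experts.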
    send_actors_list: List = [
        self.sorted_send_actors,
        self.sorted_send_actors_stage2
    ]
    actor_mappings_list: List = [
        self.send_recv_actor_mappings,
        self.send_recv_actor_mappings_stage2,
        self.send_actors_to_regroup_routed_experts,
    ]
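    # Fuse LoRA weights (if enabled) into the base weights before syncing; they are
    # unfused again once the sync below has finished.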
    self.check_and_fuse_lora(self._enable_lora, actor_mappings_list)
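    # Only the concurrent-comm path with a vLLM inference backend is implemented for
    # HEP models; any other configuration raises NotImplementedError below.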
    if self.concurrent_comm:
        assert self.dst_model.use_vllm_backend
        max_workers = self._calculate_max_workers(self.send_actors_to_regroup_routed_experts)
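        # Step 1: regroup routed-expert weights across the send actors, using the
        # configured collective (allgather or alltoall).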
        if self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLGATHER:
            # allgather routed experts only
            self.sync_allgather_multi_threads(
                [self.send_actors_to_regroup_routed_experts],
                max_workers=max_workers,
                requires_grad=requires_grad,
                group_name=self.group_name + "_allgather",
                filter_fn=self.routed_experts_filter)
        elif self._comm_type_to_regroup_routed_experts == ROUTED_EXPERT_REGROUPING_COMM_TYPE.ALLTOALL:
            if not dryrun:
                logger.info("start to alltoall routed experts.")
                start_time = time.time()
                # alltoall routed experts only
                self.sync_alltoall_multi_threads(
                    [self.send_actors_to_regroup_routed_experts],
                    max_workers=max_workers,
                    requires_grad=requires_grad,
                    filter_fn=self.routed_experts_filter)
                logger.info(f"complete to alltoall routed experts using {time.time()-start_time:.2f} seconds")
        # sync everything to inference model
        if self.tp_num_mapping == 1:
            logger.info("start to sync all moe experts")
            send_actors_list = [self.sorted_send_actors]
            actor_mappings_list = [self.send_recv_actor_mappings]
            self._multi_thread_sync_for_tp_num_mapping_eq_1(
                send_actors_list,
                actor_mappings_list,
                requires_grad=requires_grad,
                filter_fn=None,
                param_group="default",
                dryrun=dryrun
            )
            logger.info("complete to sync all moe experts")
        elif self.tp_num_mapping > 1:
            # First, synchronize routed experts.
            logger.info("start to sync routed expert weights.")
            start_time = time.time()
            self._synchronize_routed_experts(requires_grad=requires_grad, validate=validate, dryrun=dryrun)
            logger.info(f"complete to sync routed expert weights. [stage1-1] using {time.time()-start_time:.2f} seconds")
            self.clear_cache(
                sorted_send_actors_list=[
                    self.send_actors_to_regroup_routed_experts,
                    self.sorted_send_actors_for_routed_experts
                ],
                rank_mapping_list=[
                    self.send_recv_actor_mappings_for_routed_experts
                ]
            )
            # Then, synchronize parameters except routed experts.
            logger.info("start to sync parameters except routed experts.")
            self._synchronize_params_except_routed_experts(requires_grad=requires_grad, validate=validate, dryrun=dryrun)
            logger.info("complete to sync parameters except routed experts.")
            self.reset_synchronizer()
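            # Stage 2 done: release cached state for the stage-1/stage-2 send actors
            # and their rank mappings.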
            self.clear_cache(
                sorted_send_actors_list=[
                    self.sorted_send_actors,
                    self.sorted_send_actors_stage2,
                ],
                rank_mapping_list=[
                    self.send_recv_actor_mappings,
                    self.send_recv_actor_mappings_stage2
                ]
            )
        else:
            raise NotImplementedError(
                f"ChatLearn does not support synchronizing from larger tp size ({self.num_src_tensor_parallel}) "
                f"to smaller tp size ({self.num_dst_tensor_parallel}) currently."
            )
    else:
        raise NotImplementedError(
            "ChatLearn supports only concurrent_comm for training models with HEP enabled and inference with vLLM"
        )
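    # Common teardown for all paths: unfuse LoRA, optionally validate the synced
    # results, then destroy the collective group and reset synchronizer state.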
    self.check_and_unfuse_lora(self._enable_lora, actor_mappings_list)
    self.validate_sync_results_parallel(actor_mappings_list, requires_grad, validate)
    self.check_and_destroy_collective_group()
    self.reset_synchronizer()
    logger.info(f"Group {self.group_name} sync all parameters done, comm_type {self._comm_type}")