in chatlearn/synchronizer/parameter_sync.py [0:0]
def _synchronize_params_except_routed_experts(self, requires_grad=None, validate=False, dryrun=False):
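    """Synchronize all parameters except the routed-expert weights."""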
    self.check_and_setup_collective_group()
    self.check_and_fuse_lora(self._enable_lora, self.send_recv_actor_mappings)
    send_actors_list: List = []
    actor_mappings_list: List = []
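    # Pick a sync strategy: multi-threaded sync when concurrent communication is enabled,
    # otherwise a single-threaded fallback.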
    if self.concurrent_comm:
        if self.tp_num_mapping > 1:
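            # tp_num_mapping > 1: run the sync in two stages, each stage with its own
            # sorted send actors and send/recv actor mappings.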
            send_actors_list = [self.sorted_send_actors, self.sorted_send_actors_stage2]
            actor_mappings_list = [self.send_recv_actor_mappings, self.send_recv_actor_mappings_stage2]
            self._multi_thread_sync_for_tp_num_mapping_gt_1(
                send_actors_list,
                actor_mappings_list,
                requires_grad=requires_grad,
                filter_fn=self.params_except_routed_expert_filter,
                param_group="except_routed",
                dryrun=dryrun
            )
        else:
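            # tp_num_mapping == 1: a single multi-threaded pass over the mappings suffices.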
            send_actors_list = [self.sorted_send_actors]
            actor_mappings_list = [self.send_recv_actor_mappings]
            self._multi_thread_sync_for_tp_num_mapping_eq_1(
                send_actors_list,
                actor_mappings_list,
                requires_grad=requires_grad,
                filter_fn=self.params_except_routed_expert_filter,
                param_group="except_routed"
            )
    else:
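        # Concurrent communication disabled: fall back to single-threaded synchronization.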
        actor_mappings_list = [self.send_recv_actor_mappings]
        self._single_thread_sync(
            actor_mappings_list,
            requires_grad=requires_grad,
            filter_fn=self.params_except_routed_expert_filter,
            param_group="except_routed"
        )
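    # Restore any fused LoRA layers, optionally validate the synced parameters,
    # then tear down the collective group.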
    self.check_and_unfuse_lora(self._enable_lora, self.send_recv_actor_mappings)
    self.validate_sync_results_parallel(
        actor_mappings_list,
        requires_grad,
        validate,
        filter_fn=self.params_except_routed_expert_filter,
        param_group="except_routed"
    )
    self.check_and_destroy_collective_group()
    logger.info(f"Group {self.group_name} sync parameters except routed experts done, comm_type {self._comm_type}")
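
For context, below is a minimal sketch of what a filter_fn such as params_except_routed_expert_filter could look like. The standalone signature and the "mlp.experts" name pattern are assumptions for illustration only; the real filter is defined on the synchronizer class and may match different parameter names.

from typing import List

def params_except_routed_expert_filter_sketch(param_names: List[str]) -> List[str]:
    # Keep only parameter names that do not belong to routed experts.
    # "mlp.experts" is an assumed naming convention, not necessarily chatlearn's.
    return [name for name in param_names if "mlp.experts" not in name]

# Example: the attention weight is kept, the routed-expert weight is filtered out.
print(params_except_routed_expert_filter_sketch([
    "layers.0.self_attention.query_key_value.weight",
    "layers.0.mlp.experts.0.dense_h_to_4h.weight",
]))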