def sync_broadcast_multi_threads()

in chatlearn/synchronizer/parameter_sync.py


    def sync_broadcast_multi_threads(
        self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None,
        group_name=None, stage2=False, filter_fn=None, param_group="default", dryrun=False):
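        """Sync parameters from send actors to recv actors over multiple threads.

        First stage (``stage2=False``): one broadcast group is built per send
        actor and all groups are synced concurrently through a thread pool
        (with ``dryrun`` set, groups are created but nothing is submitted).
        Second stage (``stage2=True``): (send, recv) pairs are packed into
        rounds with disjoint actors and each round is synced in turn via
        ``sync_broadcast_second_stage``. Only the broadcast comm type is
        supported on both paths.
        """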

        if stage2:
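            # Flatten the send->recv mapping into individual (send_actor, recv_actor) pairs.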
            thread_group = []
            for send_actor in sorted_send_actors:
                recv_actors = send_recv_actor_mappings[send_actor]
                for recv_actor in recv_actors:
                    thread_group.append((send_actor, recv_actor))
            actor_groups_to_sync = []
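            # Greedily pack the pairs into rounds: a pair joins the first
            # round in which neither of its actors already appears; if every
            # existing round conflicts, it starts a new round. Pairs within a
            # round therefore touch disjoint actors.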
            for group in thread_group:
                new_actor_group_flag = True
                for idx, actor_groups in enumerate(actor_groups_to_sync):
                    in_actor_group = False
                    for actor_group in actor_groups:
                        if group[0] in actor_group or group[1] in actor_group:
                            in_actor_group = True
                    if not in_actor_group:
                        new_actor_group_flag = False
                        actor_groups_to_sync[idx].append(group) #pylint: disable=unnecessary-list-index-lookup
                        break
                if new_actor_group_flag or not actor_groups_to_sync:
                    actor_groups_to_sync.append([group])

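            # Rounds are executed one after another; each round is handed to
            # the second-stage broadcast as a single batch.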
            for group_idx, actor_groups in enumerate(actor_groups_to_sync):
                if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
                    self.sync_broadcast_second_stage(
                        f"{group_name}_{group_idx}",
                        actor_groups,
                        requires_grad,
                        filter_fn,
                        param_group,
                        dryrun=dryrun
                    )
                else:
                    raise RuntimeError("p2p is supported only for scenarios where trainer_tp is not equal to inference_tp.")
        else:
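            # First stage: use one worker thread per send actor so that all
            # broadcast groups are synced concurrently.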
            max_workers = len(sorted_send_actors)
            logger.info(f"Use {max_workers} workers for first_stage broadcasting.")
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                futures = []
                for send_actor in sorted_send_actors:
                    recv_actors = send_recv_actor_mappings[send_actor]
                    if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
                        actor_groups, finalized_group_name = self.create_broadcast_group(
                            send_actor, recv_actors, group_name=group_name, param_group=param_group
                        )
                        if not dryrun:
                            futures.append(executor.submit(
                                self.sync_broadcast_two_stage, actor_groups, finalized_group_name, requires_grad, stage2, filter_fn, param_group
                            ))
                    else:
                        raise RuntimeError("p2p is supported only for scenarios where trainer_tp is not equal to inference_tp.")
                for _future in concurrent.futures.as_completed(futures):
                    try:
                        _future.result()
                    except Exception as e:
                        traceback.print_exc()
                        raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from
                concurrent.futures.wait(futures)
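
The round-packing above can be exercised in isolation. Below is a minimal sketch of the same greedy bucketing applied to hashable stand-ins for actor handles; the helper name pack_into_rounds and the string actors are illustrative, not part of chatlearn:

    def pack_into_rounds(pairs):
        """Greedily pack (send, recv) pairs into rounds with disjoint actors."""
        rounds = []
        for send, recv in pairs:
            placed = False
            for round_pairs in rounds:
                # A pair conflicts with a round if either endpoint already
                # appears in any pair of that round.
                busy = {actor for pair in round_pairs for actor in pair}
                if send not in busy and recv not in busy:
                    round_pairs.append((send, recv))
                    placed = True
                    break
            if not placed:
                rounds.append([(send, recv)])
        return rounds

    # Two senders, each broadcasting to two receivers.
    pairs = [("s0", "r0"), ("s0", "r1"), ("s1", "r2"), ("s1", "r3")]
    print(pack_into_rounds(pairs))
    # [[('s0', 'r0'), ('s1', 'r2')], [('s0', 'r1'), ('s1', 'r3')]]
    # Each round uses every actor at most once, so the two broadcasts in a
    # round cannot contend for the same actor.

The first-stage path is the standard submit/as_completed pattern: each send actor's broadcast is submitted to a thread pool, and the first future that raises is re-raised as a RuntimeError after its traceback is printed. A self-contained sketch of that error-propagation pattern, where do_sync is a hypothetical stand-in for sync_broadcast_two_stage:

    import concurrent.futures
    import traceback
    from concurrent.futures import ThreadPoolExecutor

    def do_sync(group_id):
        # Pretend one broadcast group fails mid-sync.
        if group_id == 2:
            raise ValueError(f"group {group_id} failed")
        return group_id

    with ThreadPoolExecutor(max_workers=4) as executor:
        futures = [executor.submit(do_sync, i) for i in range(4)]
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                traceback.print_exc()
                raise RuntimeError(f"sync thread generated an exception: {e}") from e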