in chatlearn/synchronizer/parameter_sync.py [0:0]
def sync_broadcast_multi_threads(
    self, sorted_send_actors, send_recv_actor_mappings, max_workers=1, requires_grad=None,
    group_name=None, stage2=False, filter_fn=None, param_group="default", dryrun=False):
    """Broadcast parameters from send actors to their receive actors.

    Two modes are supported:

    * ``stage2=True``: every ``(send_actor, recv_actor)`` pair taken from
      ``send_recv_actor_mappings`` is greedily packed into groups in which no
      actor appears more than once. Groups are synced one after another via
      ``self.sync_broadcast_second_stage`` under the name
      ``f"{group_name}_{group_idx}"``.
    * ``stage2=False``: a broadcast group is created per send actor with
      ``self.create_broadcast_group``. Unless ``dryrun`` is set,
      ``self.sync_broadcast_two_stage`` is submitted for every send actor to a
      thread pool with one worker per send actor, and all results are awaited.

    Args:
        sorted_send_actors: Send actors, in the order they are processed.
        send_recv_actor_mappings: Maps each send actor to its receive actors.
        max_workers: Currently ignored; kept for interface compatibility. The
            first-stage pool always sizes itself to one worker per send actor.
        requires_grad: Forwarded to the per-group sync call.
        group_name: Base name used to build the per-group names.
        stage2: Selects the second-stage (True) or first-stage (False) path.
        filter_fn: Forwarded to the per-group sync call.
        param_group: Forwarded to group creation and the per-group sync call.
        dryrun: Stage 2 forwards it to ``sync_broadcast_second_stage``; stage 1
            still creates the broadcast groups but submits no sync work.

    Raises:
        RuntimeError: If ``self._comm_type`` is not
            ``PARAM_SYNC_COMM_TYPE.BROADCAST`` (and there is work to do), or if
            any first-stage sync thread raises; the thread's exception is
            chained as ``__cause__``.
    """
    if stage2:
        # Pack (send, recv) pairs into groups that share no actor: each pair
        # joins the first existing group touching neither of its actors,
        # otherwise it starts a new group.
        actor_groups_to_sync = []
        for send_actor in sorted_send_actors:
            for recv_actor in send_recv_actor_mappings[send_actor]:
                pair = (send_actor, recv_actor)
                for pairs in actor_groups_to_sync:
                    if not any(send_actor in other or recv_actor in other for other in pairs):
                        pairs.append(pair)
                        break
                else:
                    actor_groups_to_sync.append([pair])
        for group_idx, actor_groups in enumerate(actor_groups_to_sync):
            if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
                self.sync_broadcast_second_stage(
                    f"{group_name}_{group_idx}",
                    actor_groups,
                    requires_grad,
                    filter_fn,
                    param_group,
                    dryrun=dryrun
                )
            else:
                raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.")
        return

    # One worker per send actor; the max_workers argument is intentionally
    # overridden. ThreadPoolExecutor rejects max_workers=0, so keep at least
    # one worker when there are no send actors.
    max_workers = max(len(sorted_send_actors), 1)
    logger.info(f"Use {max_workers} workers for first_stage broadcasting.")
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = []
        for send_actor in sorted_send_actors:
            recv_actors = send_recv_actor_mappings[send_actor]
            if self._comm_type == PARAM_SYNC_COMM_TYPE.BROADCAST:
                actor_groups, finalized_group_name = self.create_broadcast_group(
                    send_actor, recv_actors, group_name=group_name, param_group=param_group
                )
                if not dryrun:
                    futures.append(executor.submit(
                        self.sync_broadcast_two_stage, actor_groups, finalized_group_name,
                        requires_grad, stage2, filter_fn, param_group
                    ))
            else:
                raise RuntimeError("support p2p only for scenes that trainer_tp not equal to inference_tp.")
        # as_completed yields every future, so no separate wait() is needed;
        # the first failure is surfaced with its original exception chained.
        for future in concurrent.futures.as_completed(futures):
            try:
                future.result()
            except Exception as e:
                traceback.print_exc()
                raise RuntimeError(f"Parameter sync thread generated an exception: {e}") from e