in src/datatrove/executor/local.py
def run(self):
"""
This method is responsible for correctly invoking `self._run_for_rank` for each task that is to be run.
On a LocalPipelineExecutor, this method will spawn a multiprocess pool if workers != 1.
Otherwise, ranks will be run sequentially in a loop.
Returns:
"""
assert not self.depends or isinstance(self.depends, LocalPipelineExecutor), (
"depends= must be a LocalPipelineExecutor"
)
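# If this executor depends on another one, launch it first (if needed) and
# then poll until all of its ranks have completed.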
if self.depends:
# take care of launching any unlaunched dependencies
if not self.depends._launched:
logger.info(f'Launching dependency job "{self.depends}"')
self.depends.run()
while (incomplete := len(self.depends.get_incomplete_ranks())) > 0:
logger.info(f"Dependency job still has {incomplete}/{self.depends.world_size} tasks. Waiting...")
time.sleep(2 * 60)
self._launched = True
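# Short-circuit if every local rank has already been completed in a previous run.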
if all(map(self.is_rank_completed, range(self.local_rank_offset, self.local_rank_offset + self.local_tasks))):
logger.info(f"Not doing anything as all {self.local_tasks} tasks have already been completed.")
return
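# Persist this executor's configuration as json so the run can be inspected later.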
self.save_executor_as_json()
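# ranks_q hands out the available worker slot ids (0..workers-1): each task
# grabs a free slot id while it runs and puts it back when it finishes
# (see _launch_run_for_rank).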
mg = multiprocess.Manager()
ranks_q = mg.Queue()
for i in range(self.workers):
ranks_q.put(i)
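# Only ranks without a completion marker are (re)run; the rest are skipped.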
ranks_to_run = self.get_incomplete_ranks(
range(self.local_rank_offset, self.local_rank_offset + self.local_tasks)
)
if (skipped := self.local_tasks - len(ranks_to_run)) > 0:
logger.info(f"Skipping {skipped} already completed tasks")
if self.workers == 1:
pipeline = self.pipeline
stats = []
for rank in ranks_to_run:
self.pipeline = deepcopy(pipeline)
stats.append(self._launch_run_for_rank(rank, ranks_q))
else:
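# Shared counter + lock so that concurrent workers can atomically update
# (and log) the number of completed tasks.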
completed_counter = mg.Value("i", skipped)
completed_lock = mg.Lock()
ctx = multiprocess.get_context(self.start_method)
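# imap_unordered yields each rank's stats as soon as that task finishes,
# in completion order rather than submission order.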
with ctx.Pool(self.workers) as pool:
stats = list(
pool.imap_unordered(
partial(
self._launch_run_for_rank,
ranks_q=ranks_q,
completed=completed_counter,
completed_lock=completed_lock,
),
ranks_to_run,
)
)
# merge the per-rank stats into a single PipelineStats
stats = sum(stats, start=PipelineStats())
with self.logging_dir.open("stats.json", "wt") as statsfile:
stats.save_to_disk(statsfile)
logger.success(stats.get_repr(f"All {self.local_tasks} tasks"))
return stats
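
# Example usage (a minimal sketch, not from this file; the JsonlReader step and
# the paths below are illustrative assumptions):
#
#     executor = LocalPipelineExecutor(
#         pipeline=[JsonlReader("data/")],
#         tasks=4,
#         workers=2,
#         logging_dir="logs/",
#     )
#     stats = executor.run()  # returns the merged PipelineStats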