def run()

in src/datatrove/executor/ray.py


    def run(self):
        """
        Run the pipeline for each rank using Ray tasks.
        """

        check_required_dependencies("ray", ["ray"])
        import ray

        # 1) If a dependency executor (depends=) was given, ensure it has run and finished
        if self.depends:
            logger.info(f'Launching dependency job "{self.depends}"')
            self.depends.run()

        # 2) Check if all tasks are already completed
        incomplete_ranks = self.get_incomplete_ranks(range(self.world_size))
        if not incomplete_ranks:
            logger.info(f"All {self.world_size} tasks appear to be completed already. Nothing to run.")
            return

        logger.info(f"Will run pipeline on {len(incomplete_ranks)} incomplete ranks out of {self.world_size} total.")

        # 3) Save executor JSON
        self.save_executor_as_json()

        executor_ref = ray.put(self)

        # 4) Define resource requirements for this pipeline's tasks
        remote_options = {
            "num_cpus": self.cpus_per_task,
            "num_gpus": 0,
            "memory": int(self.mem_per_cpu_gb * self.cpus_per_task * 1024 * 1024 * 1024),
        }
        if self.ray_remote_kwargs:
            remote_options.update(self.ray_remote_kwargs)

        # 5) Dispatch Ray tasks
        MAX_CONCURRENT_TASKS = self.workers
        ranks_per_jobs = [
            incomplete_ranks[i : i + self.tasks_per_job] for i in range(0, len(incomplete_ranks), self.tasks_per_job)
        ]
        unfinished = []
        total_tasks = len(ranks_per_jobs)
        completed = 0

        ray_remote_func = ray.remote(**remote_options)(run_for_rank)

        # 6) Submit the initial batch of tasks, recording each task's start time
        task_start_times = {}
        for _ in range(min(MAX_CONCURRENT_TASKS, len(ranks_per_jobs))):
            ranks_to_submit = ranks_per_jobs.pop(0)
            task = ray_remote_func.remote(executor_ref, ranks_to_submit)
            unfinished.append(task)
            task_start_times[task] = time.time()

        # 7) Wait for tasks to finish, submitting new ones as slots free up and enforcing the time limit.
        while unfinished:
            finished, unfinished = ray.wait(unfinished, num_returns=len(unfinished), timeout=10)
            for task in finished:
                # Drop the start-time entry of each finished task
                del task_start_times[task]

            try:
                results = ray.get(finished)
                completed += len(results)
            except Exception as e:
                logger.exception(f"Error processing rank: {e}")

            # If we have more ranks left to process and we haven't hit the max
            # number of concurrent tasks, add tasks to the unfinished queue.
            while ranks_per_jobs and len(unfinished) < MAX_CONCURRENT_TASKS:
                ranks_to_submit = ranks_per_jobs.pop(0)
                task = ray_remote_func.remote(executor_ref, ranks_to_submit)
                unfinished.append(task)
                task_start_times[task] = time.time()

            # Finally, kill tasks that have been running for more than self.time seconds
            if self.time:
                for task in list(unfinished):  # iterate over a copy, as we mutate `unfinished` below
                    if time.time() - task_start_times[task] > self.time:
                        # No mercy :) -> force=True SIGKILLs the worker running the task
                        ray.cancel(task, force=True)
                        del task_start_times[task]
                        unfinished.remove(task)
                        logger.warning(f"Task {task} timed out after {self.time} seconds and was killed.")
        logger.info("All Ray tasks have finished.")

        # 8) Merge stats of all ranks
        if completed == total_tasks:
            total_stats = PipelineStats()
            for rank in range(self.world_size):
                with self.logging_dir.open(f"stats/{rank:05d}.json", "r") as f:
                    total_stats += PipelineStats.from_json(json.load(f))
            with self.logging_dir.open("stats.json", "wt") as statsfile:
                total_stats.save_to_disk(statsfile)
            logger.success(total_stats.get_repr(f"All {completed}/{total_tasks} tasks."))
        else:
            logger.warning(f"Only {completed}/{total_tasks} tasks completed.")