graphlearn_torch/python/distributed/event_loop.py (59 lines of code) (raw):

# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import asyncio
import logging
from threading import Thread, BoundedSemaphore

import torch


def wrap_torch_future(f: torch.futures.Future) -> asyncio.futures.Future:
    r""" Convert a torch future to a standard asyncio future.

    The returned asyncio future belongs to the event loop returned by
    ``asyncio.get_event_loop()`` at call time, and is completed on that
    loop once ``f`` completes.

    Args:
      f: the torch future to wrap.

    Returns:
      An asyncio future that resolves to the result of ``f``, or raises the
      exception raised by ``f.wait()``.
    """
    loop = asyncio.get_event_loop()
    aio_future = loop.create_future()

    def _resolve(setter, value):
        # Executed on the loop thread. If the awaiting side has already
        # cancelled (or otherwise completed) the asyncio future, setting a
        # result would raise InvalidStateError inside the loop, so skip it.
        if not aio_future.done():
            setter(value)

    def on_done(*_):
        # The torch callback may fire on a thread other than the loop
        # thread, so the asyncio future is only touched through
        # `call_soon_threadsafe`.
        try:
            result = f.wait()
        except Exception as e:
            loop.call_soon_threadsafe(_resolve, aio_future.set_exception, e)
        else:
            loop.call_soon_threadsafe(_resolve, aio_future.set_result, result)

    f.add_done_callback(on_done)
    return aio_future


class ConcurrentEventLoop(object):
    r""" Concurrent event loop context.

    Owns a private asyncio event loop that runs on a daemon thread, and
    limits the number of in-flight tasks with a bounded semaphore.

    Args:
      concurrency: max processing concurrency.
    """

    def __init__(self, concurrency):
        self._concurrency = concurrency
        # One permit per in-flight task.
        self._sem = BoundedSemaphore(concurrency)
        self._loop = asyncio.new_event_loop()
        self._runner_t = Thread(target=self._run_loop)
        self._runner_t.daemon = True

    def start_loop(self):
        r""" Start the background thread running the event loop, if it is
        not already running.
        """
        if not self._runner_t.is_alive():
            self._runner_t.start()

    def shutdown_loop(self):
        r""" Wait for all pending tasks, then stop the event loop thread.
        """
        self.wait_all()
        if self._runner_t.is_alive():
            # `loop.stop()` is not thread-safe and, when called from another
            # thread, does not wake an idle loop blocked in its selector.
            # Schedule it through `call_soon_threadsafe` so that the loop
            # wakes up and actually stops.
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._runner_t.join(timeout=1)

    def wait_all(self):
        r""" Wait all pending tasks to be finished.
        """
        # Holding every permit implies that no task is still in flight.
        for _ in range(self._concurrency):
            self._sem.acquire()
        for _ in range(self._concurrency):
            self._sem.release()

    def add_task(self, coro, callback=None):
        r""" Add an asynchronized coroutine task to run.

        Blocks the caller while `concurrency` tasks are already in flight.

        Args:
          coro: the async coroutine func.
          callback: the callback func applied on the returned results
            after the coroutine task is finished.

        Note that any results returned by `callback` func will be ignored,
        so it is preferable to handle all in your `callback` func and do
        not return any results.
        """
        def on_done(f: asyncio.futures.Future):
            try:
                res = f.result()
                if callback is not None:
                    callback(res)
            except Exception as e:
                logging.error("coroutine task failed: %s", e)
            finally:
                # Always hand the permit back, otherwise `wait_all` and
                # later `add_task` calls could block forever.
                self._sem.release()

        self._sem.acquire()
        try:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except BaseException:
            # Scheduling failed (e.g. `coro` is not a coroutine object), so
            # `on_done` will never run; return the permit here instead.
            self._sem.release()
            raise
        fut.add_done_callback(on_done)

    def run_task(self, coro):
        r""" Run a coroutine task synchronously.

        Args:
          coro: the coroutine to run on the event loop.

        Returns:
          The result of the coroutine; an exception raised by the coroutine
          is re-raised to the caller.
        """
        with self._sem:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return fut.result()

    def _run_loop(self):
        # Thread target: serve the private loop until it is stopped.
        self._loop.run_forever()