graphlearn_torch/python/distributed/event_loop.py (59 lines of code) (raw):
# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import asyncio
import logging
from threading import Thread, BoundedSemaphore
import torch
def wrap_torch_future(f: torch.futures.Future) -> asyncio.futures.Future:
    r""" Convert a torch future to a standard asyncio future.

    The asyncio future is created on the loop returned by
    ``asyncio.get_event_loop()`` at call time. It is completed from ``f``'s
    done-callback through ``loop.call_soon_threadsafe``, so ``f`` may be
    completed on any thread.

    Args:
      f: the torch future to wrap.

    Returns:
      An asyncio future that resolves to ``f``'s value, or fails with the
      exception raised by ``f.wait()``. If the asyncio future has already
      been cancelled (e.g. the awaiting task was cancelled or timed out)
      by the time ``f`` completes, ``f``'s outcome is dropped instead of
      raising ``InvalidStateError`` inside the event loop.
    """
    loop = asyncio.get_event_loop()
    aio_future = loop.create_future()

    def _set_result(result):
        # Runs on the event loop; skip if the awaiting side already gave up.
        if not aio_future.done():
            aio_future.set_result(result)

    def _set_exception(exc):
        # Runs on the event loop; skip if the awaiting side already gave up.
        if not aio_future.done():
            aio_future.set_exception(exc)

    def on_done(*_):
        # Invoked by torch on whichever thread completes ``f``; ``f`` is done
        # here, so ``wait()`` returns (or raises) immediately.
        try:
            result = f.wait()
        except Exception as e:
            loop.call_soon_threadsafe(_set_exception, e)
        else:
            loop.call_soon_threadsafe(_set_result, result)

    f.add_done_callback(on_done)
    return aio_future
class ConcurrentEventLoop(object):
    r""" Concurrent event loop context.

    Owns a private asyncio event loop that runs on a background daemon
    thread. Other threads submit coroutines to it with ``add_task``
    (fire-and-forget, with an optional result callback) or ``run_task``
    (blocking). A bounded semaphore limits how many submitted coroutines
    may be in flight at the same time.

    Args:
      concurrency: max processing concurrency.
    """

    def __init__(self, concurrency):
        self._concurrency = concurrency
        # One permit per in-flight task; ``wait_all`` drains all of them.
        self._sem = BoundedSemaphore(concurrency)
        self._loop = asyncio.new_event_loop()
        # Daemon thread so a loop that was never shut down does not keep the
        # interpreter alive at exit.
        self._runner_t = Thread(target=self._run_loop)
        self._runner_t.daemon = True

    def start_loop(self):
        r""" Start the loop thread if it is not already running.
        """
        if not self._runner_t.is_alive():
            self._runner_t.start()

    def shutdown_loop(self):
        r""" Wait for all in-flight tasks, then stop the event loop and join
        its thread for at most one second.
        """
        self.wait_all()
        if self._runner_t.is_alive():
            # ``loop.stop()`` must not be called directly from a foreign
            # thread: it would only set a flag, and a loop idling in its
            # selector would never wake up to see it (so the join below would
            # just time out). ``call_soon_threadsafe`` wakes the loop up.
            self._loop.call_soon_threadsafe(self._loop.stop)
            self._runner_t.join(timeout=1)

    def wait_all(self):
        r""" Wait all pending tasks to be finished.

        Blocks until every semaphore permit can be acquired, i.e. until no
        task submitted via ``add_task``/``run_task`` is still in flight.
        Then all permits are handed back so new tasks can be submitted.
        """
        for _ in range(self._concurrency):
            self._sem.acquire()
        for _ in range(self._concurrency):
            self._sem.release()

    def add_task(self, coro, callback=None):
        r""" Add an asynchronized coroutine task to run.

        Blocks the caller while ``concurrency`` tasks are already in flight.

        Args:
          coro: the coroutine object to schedule on the loop.
          callback: the callback func applied on the returned results
            after the coroutine task is finished.
            Note that any results returned by `callback` func will be ignored,
            so it is preferable to handle all in your `callback` func and do
            not return any results. Exceptions raised by the coroutine or by
            `callback` are logged with their traceback and otherwise ignored.

        Raises:
          Whatever ``asyncio.run_coroutine_threadsafe`` raises when the task
          cannot be scheduled (e.g. ``TypeError`` if ``coro`` is not a
          coroutine). The concurrency slot is released in that case.
        """
        def on_done(f):
            # ``f`` is the ``concurrent.futures.Future`` returned by
            # ``run_coroutine_threadsafe``; this runs on whichever thread
            # completes it (normally the event-loop thread).
            try:
                res = f.result()
                if callback is not None:
                    callback(res)
            except Exception as e:
                logging.exception("coroutine task failed: %s", e)
            finally:
                # Always free the slot, even if something that is not an
                # ``Exception`` escapes.
                self._sem.release()

        self._sem.acquire()
        try:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        except BaseException:
            # Scheduling failed, so ``on_done`` will never run: give the slot
            # back here, otherwise it leaks and ``wait_all`` hangs forever.
            self._sem.release()
            raise
        fut.add_done_callback(on_done)

    def run_task(self, coro):
        r""" Run a coroutine task synchronously.

        Occupies one concurrency slot while the coroutine runs and blocks the
        calling thread until it finishes.

        NOTE: do not call this from the event-loop thread itself (e.g. from
        inside a coroutine or an ``add_task`` callback): ``fut.result()``
        would block the very loop that has to run ``coro``.

        Returns:
          The coroutine's return value; an exception it raises is re-raised.
        """
        with self._sem:
            fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
            return fut.result()

    def _run_loop(self):
        # Thread target: run the loop until a ``stop`` call ends it.
        self._loop.run_forever()