client/commands/async_server_connection.py (212 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import abc
import asyncio
import logging
import sys
from pathlib import Path
from typing import AsyncIterator, Tuple, Optional, List
if sys.version_info >= (3, 7):
from contextlib import asynccontextmanager
else:
from async_generator import asynccontextmanager
# Module-level logger, named after this module through `__name__`.
LOG: logging.Logger = logging.getLogger(__name__)
class ConnectionFailure(Exception):
    """
    Raised by `connect` in place of an `OSError` that occurred while the
    socket connection was being set up or used.
    """
class BytesReader(abc.ABC):
    """
    Abstract interface of an asynchronous input channel that produces bytes.
    """

    @abc.abstractmethod
    async def read_until(self, separator: bytes = b"\n") -> bytes:
        """
        Keep reading from the stream up to the point where `separator` shows
        up.

        Implementations are expected to raise `asyncio.IncompleteReadError`
        when EOF is hit before the whole separator has been seen.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def read_exactly(self, count: int) -> bytes:
        """
        Read precisely `count` bytes.

        Implementations are expected to raise `asyncio.IncompleteReadError`
        when EOF is hit before `count` bytes could be read.
        """
        raise NotImplementedError

    async def readline(self) -> bytes:
        """
        Read a single line, i.e. a run of bytes terminated by '\n'.

        This delegates to `read_until(b"\n")`. Unlike `read_until`, hitting
        EOF before any '\n' is not an error: whatever was read up to that
        point is returned instead.
        """
        try:
            line = await self.read_until(b"\n")
        except asyncio.IncompleteReadError as incomplete:
            # EOF came first: hand back the partial payload carried by the
            # exception rather than propagating it.
            line = incomplete.partial
        return line
class BytesWriter(abc.ABC):
    """
    Abstract interface of an asynchronous output channel that accepts bytes.
    """

    @abc.abstractmethod
    async def write(self, data: bytes) -> None:
        """
        Send `data` to the underlying channel. Implementations are expected
        to flush right away instead of buffering.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def close(self) -> None:
        """
        Close the underlying channel. Implementations are expected to return
        only once the channel has been fully closed.
        """
        raise NotImplementedError
class TextReader:
    """
    A thin `str`-based facade over a `BytesReader`.

    Every read is forwarded to the wrapped reader and the bytes it returns
    are decoded with `encoding` before being handed back, so the methods here
    mirror those of `BytesReader` but deal in strings.
    """

    bytes_reader: BytesReader
    encoding: str

    def __init__(self, bytes_reader: BytesReader, encoding: str = "utf-8") -> None:
        self.encoding = encoding
        self.bytes_reader = bytes_reader

    async def read_until(self, separator: str = "\n") -> str:
        """
        Encode `separator`, read from the wrapped reader until it is found,
        and return the decoded result.
        """
        raw = await self.bytes_reader.read_until(separator.encode(self.encoding))
        return str(raw, self.encoding)

    async def read_exactly(self, count: int) -> str:
        """
        Read `count` bytes from the wrapped reader and decode them. Note that
        `count` is forwarded as-is, so it counts bytes, not characters.
        """
        raw = await self.bytes_reader.read_exactly(count)
        return str(raw, self.encoding)

    async def readline(self) -> str:
        """
        Read one line from the wrapped reader and return it decoded.
        """
        raw = await self.bytes_reader.readline()
        return str(raw, self.encoding)
class TextWriter:
    """
    A thin `str`-based facade over a `BytesWriter`.

    Strings passed to `write` are encoded with `encoding` and forwarded to
    the wrapped writer right away.
    """

    bytes_writer: BytesWriter
    encoding: str

    def __init__(self, bytes_writer: BytesWriter, encoding: str = "utf-8") -> None:
        self.encoding = encoding
        self.bytes_writer = bytes_writer

    async def write(self, data: str) -> None:
        """
        Encode `data` and pass the resulting bytes to the wrapped writer.
        """
        await self.bytes_writer.write(data.encode(self.encoding))
class MemoryBytesReader(BytesReader):
    """
    An implementation of `BytesReader` based on a given in-memory byte string.

    The reader keeps a cursor into the data. Each successful read returns the
    bytes starting at the cursor and advances it; a read that runs into the
    end of the data moves the cursor to the end and raises
    `asyncio.IncompleteReadError` carrying whatever was left.
    """

    _data: bytes
    _cursor: int

    def __init__(self, data: bytes) -> None:
        self._data = data
        self._cursor = 0

    async def read_until(self, separator: bytes = b"\n") -> bytes:
        """
        Return the bytes from the cursor up to and including the first
        occurrence of `separator`.

        Raise `ValueError` if `separator` is empty, and
        `asyncio.IncompleteReadError` (with the remaining bytes as `partial`)
        if the separator does not occur in the remaining data.
        """
        if not separator:
            # An empty separator is meaningless; reject it the same way
            # `asyncio.StreamReader.readuntil` does.
            raise ValueError("Separator should be at least one-byte string")
        start = self._cursor
        found_at = self._data.find(separator, start)
        if found_at < 0:
            self._cursor = len(self._data)
            raise asyncio.IncompleteReadError(bytes(self._data[start:]), None)
        self._cursor = found_at + len(separator)
        return bytes(self._data[start : self._cursor])

    async def read_exactly(self, count: int) -> bytes:
        """
        Return the next `count` bytes.

        Raise `ValueError` if `count` is negative, and
        `asyncio.IncompleteReadError` (with the remaining bytes as `partial`)
        if fewer than `count` bytes are left.
        """
        if count < 0:
            # A negative count would otherwise move the cursor backwards and
            # return a meaningless slice.
            raise ValueError("read_exactly size can not be less than zero")
        start = self._cursor
        end = start + count
        data_size = len(self._data)
        if end > data_size:
            self._cursor = data_size
            raise asyncio.IncompleteReadError(self._data[start:], count)
        self._cursor = end
        return self._data[start:end]

    def reset(self) -> None:
        """
        Move the cursor back to the beginning of the data.
        """
        self._cursor = 0
def create_memory_text_reader(data: str, encoding: str = "utf-8") -> TextReader:
    """
    Build a `TextReader` that reads from `data`. The string is encoded with
    `encoding` into an in-memory bytes reader, and the same encoding is used
    to decode what is read back.
    """
    encoded = data.encode(encoding)
    return TextReader(MemoryBytesReader(encoded), encoding)
class MemoryBytesWriter(BytesWriter):
    """
    An implementation of `BytesWriter` that records every written chunk in an
    in-memory list.
    """

    _items: List[bytes]

    def __init__(self) -> None:
        self._items = list()

    async def write(self, data: bytes) -> None:
        """
        Record `data` as the newest chunk.
        """
        self._items.append(data)

    async def close(self) -> None:
        """
        Do nothing: there is no underlying channel to close.
        """
        return None

    def items(self) -> List[bytes]:
        """
        Return the chunks written so far, in write order. The returned list
        is the writer's internal list, not a copy.
        """
        return self._items
def create_memory_text_writer(encoding: str = "utf-8") -> TextWriter:
    """
    Build a `TextWriter` backed by a fresh `MemoryBytesWriter`. Strings
    written to it are encoded with `encoding` before being recorded.
    """
    # Forward `encoding`; otherwise the argument would be silently ignored
    # and the writer would always fall back to its own default.
    return TextWriter(MemoryBytesWriter(), encoding)
class StreamBytesReader(BytesReader):
    """
    An implementation of `BytesReader` based on `asyncio.StreamReader`.
    """

    stream_reader: asyncio.StreamReader

    def __init__(self, stream_reader: asyncio.StreamReader) -> None:
        self.stream_reader = stream_reader

    async def read_until(self, separator: bytes = b"\n") -> bytes:
        """
        Read from the stream until `separator` is found and return everything
        read, regardless of the stream's buffer limit.

        Raise `asyncio.IncompleteReadError` if EOF is reached first; its
        `partial` attribute holds all the bytes read by this call.
        """
        # StreamReader.readuntil() may raise when its internal buffer cannot hold
        # all the input data. We need to explicitly handle the raised exceptions
        # by "parking" all partial-read results in memory.
        chunks: List[bytes] = []
        while True:
            try:
                chunks.append(await self.stream_reader.readuntil(separator))
                break
            except asyncio.LimitOverrunError as error:
                # The data is still sitting in the stream's buffer: drain the
                # part that is known not to contain the separator and retry.
                chunks.append(await self.stream_reader.readexactly(error.consumed))
            except asyncio.IncompleteReadError as error:
                if not chunks:
                    raise
                # EOF after some chunks were already parked: the exception
                # only carries the last fragment, so re-raise with everything
                # read so far to avoid dropping the earlier chunks.
                raise asyncio.IncompleteReadError(
                    b"".join(chunks) + error.partial, error.expected
                ) from error
        return b"".join(chunks)

    async def read_exactly(self, count: int) -> bytes:
        """
        Read exactly `count` bytes by delegating to
        `asyncio.StreamReader.readexactly`.
        """
        return await self.stream_reader.readexactly(count)
class StreamBytesWriter(BytesWriter):
    """
    An implementation of `BytesWriter` based on `asyncio.StreamWriter`.
    """

    stream_writer: asyncio.StreamWriter

    def __init__(self, stream_writer: asyncio.StreamWriter) -> None:
        self.stream_writer = stream_writer

    async def write(self, data: bytes) -> None:
        """
        Write `data` to the stream, then wait on `drain()` before returning.
        """
        self.stream_writer.write(data)
        await self.stream_writer.drain()

    async def close(self) -> None:
        """
        Close the stream and wait until it is fully closed.
        """
        self.stream_writer.close()
        await self._stream_writer_wait_closed()

    async def _stream_writer_wait_closed(self) -> None:
        """
        StreamWriter does not have a `wait_closed` method prior to python
        3.7. For 3.6 compatibility we have to hack it with an async busy
        loop that waits
        - first for the transport to be aware that it is closing
        - then for the socket to become unmapped
        This approach is inspired by the solution in qemu.aqmp.util.
        """
        if sys.version_info >= (3, 7):
            return await self.stream_writer.wait_closed()

        transport = self.stream_writer.transport
        while not transport.is_closing():
            await asyncio.sleep(0)

        # `get_extra_info` returns `None` when the transport has no socket.
        # The result is left unannotated on purpose: the `sys` module has no
        # `IO` attribute to annotate it with.
        transport_socket = transport.get_extra_info("socket")
        if transport_socket is not None:
            while transport_socket.fileno() != -1:
                await asyncio.sleep(0)
@asynccontextmanager
async def connect(
    socket_path: Path, buffer_size: Optional[int] = None
) -> AsyncIterator[Tuple[BytesReader, BytesWriter]]:
    """
    Connect to the socket at given path. Once connected, create an input and
    an output stream from the socket. Both the input stream and the output
    stream are in raw binary mode. The API is intended to be used like this:

    ```
    async with connect(socket_path) as (input_stream, output_stream):
        # Read from input_stream and write into output_stream here
        ...
    ```

    The optional `buffer_size` argument determines the size of the input buffer
    used by the returned reader instance. If not specified, a default value of
    64kb will be used.

    Socket creation, connection, and closure will be automatically handled
    inside this context manager. If any of the socket operations fail with an
    `OSError` (including one raised inside the `async with` body or while
    closing the connection), raise `ConnectionFailure`.
    """
    writer: Optional[BytesWriter] = None
    try:
        try:
            limit = buffer_size if buffer_size is not None else 2 ** 16
            stream_reader, stream_writer = await asyncio.open_unix_connection(
                str(socket_path), limit=limit
            )
            reader = StreamBytesReader(stream_reader)
            writer = StreamBytesWriter(stream_writer)
            yield reader, writer
        finally:
            # Cleanup lives inside the outer `try` so that an `OSError` raised
            # while closing is translated like every other socket failure.
            if writer is not None:
                await writer.close()
    except OSError as error:
        raise ConnectionFailure() from error
@asynccontextmanager
async def connect_in_text_mode(
    socket_path: Path, buffer_size: Optional[int] = None
) -> AsyncIterator[Tuple[TextReader, TextWriter]]:
    """
    A higher-level, text-oriented variant of `connect`, for callers that do
    not want to deal with the complexity of binary I/O.

    The connection is established through `connect` with the same arguments;
    the difference is that the yielded streams wrap the binary ones and
    read/write UTF-8 encoded `str` instead of `bytes`.
    """
    async with connect(socket_path, buffer_size) as (bytes_reader, bytes_writer):
        text_reader = TextReader(bytes_reader, encoding="utf-8")
        text_writer = TextWriter(bytes_writer, encoding="utf-8")
        yield text_reader, text_writer
async def create_async_stdin_stdout() -> Tuple[TextReader, TextWriter]:
    """
    Wrap `sys.stdin` and `sys.stdout` into async text channels.

    Out of the box, `sys.stdin` and `sys.stdout` are synchronous: reads and
    writes block until they succeed, unlike the async socket channels created
    via `connect` or `connect_in_text_mode`.

    This function attaches both of them to the event loop as pipes and
    returns a reader/writer pair with the same interface as the socket-based
    channels. That makes it easier to write low-level-I/O-agnostic code,
    where the high-level logic does not need to care whether the underlying
    async channel comes from a socket or from stdin/stdout.
    """
    loop = asyncio.get_event_loop()

    # Input side: a stream reader fed by a protocol attached to stdin.
    stdin_stream = asyncio.StreamReader(loop=loop)

    def make_stdin_protocol() -> asyncio.StreamReaderProtocol:
        return asyncio.StreamReaderProtocol(stdin_stream)

    await loop.connect_read_pipe(make_stdin_protocol, sys.stdin)

    # Output side: a stream writer on top of a write pipe attached to stdout.
    stdout_transport, stdout_protocol = await loop.connect_write_pipe(
        asyncio.streams.FlowControlMixin, sys.stdout
    )
    stdout_stream = asyncio.StreamWriter(
        stdout_transport, stdout_protocol, stdin_stream, loop
    )

    text_reader = TextReader(StreamBytesReader(stdin_stream))
    text_writer = TextWriter(StreamBytesWriter(stdout_stream))
    return text_reader, text_writer
class BackgroundTask(abc.ABC):
    """
    Abstract interface of a unit of work exposing a single async `run`
    entry point.
    """

    @abc.abstractmethod
    async def run(self) -> None:
        """
        Carry out the work of the task. Subclasses must override this.
        """
        raise NotImplementedError
class BackgroundTaskManager:
    """
    This class manages the lifetime of a given background task.

    It maintains one piece of internal state: the existence of an ongoing
    task, represented as an attribute of type `Optional[Future]`. When the
    attribute is not `None`, it means that the task is actively running in the
    background.
    """

    _task: BackgroundTask
    _ongoing: "Optional[asyncio.Future[None]]"

    def __init__(self, task: BackgroundTask) -> None:
        """
        Initialize a background task manager. The `task` parameter is a
        `BackgroundTask` whose `run()` coroutine will be executed once
        `ensure_task_running()` is invoked.

        It is expected that the provided task does not internally swallow asyncio
        `CancelledError`. Otherwise, task shutdown may not work properly.
        """
        self._task = task
        self._ongoing = None

    async def _run_task(self) -> None:
        """
        Run the managed task to completion. Cancellation and any other
        exception raised by the task are logged rather than propagated, and
        the manager is marked as idle afterwards in every case.
        """
        try:
            await self._task.run()
        except asyncio.CancelledError:
            LOG.info("Terminate background task on explicit cancelling request.")
        except Exception as error:
            LOG.error(f"Background task unexpectedly quited: {error}")
        finally:
            self._ongoing = None

    def is_task_running(self) -> bool:
        """
        Tell whether the manager currently holds an ongoing task.
        """
        return self._ongoing is not None

    async def ensure_task_running(self) -> None:
        """
        If the background task is not currently running, schedule it to run
        in the future by adding the task to the event loop. Note that the
        scheduled task won't get a chance to execute unless control is somehow
        yielded to the event loop from the current task (e.g. via an `await` on
        something).
        """
        if self._ongoing is None:
            self._ongoing = asyncio.ensure_future(self._run_task())

    async def ensure_task_stop(self) -> None:
        """
        If the background task is running actively, make sure it gets stopped.
        """
        ongoing = self._ongoing
        if ongoing is None:
            return
        try:
            ongoing.cancel()
            await ongoing
        except asyncio.CancelledError:
            # This catch is needed when `ongoing.cancel` is called before
            # `_run_task` gets a chance to execute.
            LOG.info("Terminate background task on explicit cancelling request.")
        finally:
            # Only clear the handle if it still refers to the task stopped
            # here. A new task may have been scheduled while this coroutine
            # was suspended in the `await` above, and must not be forgotten.
            if self._ongoing is ongoing:
                self._ongoing = None