# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import abc
import asyncio
import logging
import sys
from pathlib import Path
from typing import AsyncIterator, Tuple, Optional, List

if sys.version_info >= (3, 7):
    from contextlib import asynccontextmanager
else:
    from async_generator import asynccontextmanager

# Module-level logger, named after this module.
LOG: logging.Logger = logging.getLogger(__name__)


class ConnectionFailure(Exception):
    """
    Raised by `connect` in place of an `OSError` that surfaces while the
    socket connection is being opened or used.
    """


class BytesReader(abc.ABC):
    """
    Abstract interface of an asynchronous input channel that yields bytes.

    Subclasses must provide `read_until` and `read_exactly`; `readline` is
    implemented on top of `read_until`.
    """

    @abc.abstractmethod
    async def read_until(self, separator: bytes = b"\n") -> bytes:
        """
        Keep reading from the stream until `separator` shows up.

        Raise `asyncio.IncompleteReadError` if EOF is hit before a complete
        separator has been seen.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    async def read_exactly(self, count: int) -> bytes:
        """
        Read exactly `count` bytes from the stream.

        Raise `asyncio.IncompleteReadError` if EOF is hit before `count`
        bytes could be read.
        """
        raise NotImplementedError()

    async def readline(self) -> bytes:
        """
        Read a single line, i.e. a run of bytes terminated by '\n'.

        Unlike `read_until`, reaching EOF before '\n' is not an error: whatever
        was read up to that point is returned instead.
        """
        try:
            line = await self.read_until(b"\n")
        except asyncio.IncompleteReadError as incomplete:
            # EOF before the newline: hand back what was collected so far.
            line = incomplete.partial
        return line


class BytesWriter(abc.ABC):
    """
    Abstract interface of an asynchronous output channel that accepts bytes.
    """

    @abc.abstractmethod
    async def write(self, data: bytes) -> None:
        """
        Try to write `data` to the underlying channel, flushing right away.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    async def close(self) -> None:
        """
        Close the underlying channel and wait for the close to complete.
        """
        raise NotImplementedError()


class TextReader:
    """
    String-based view over a `BytesReader`.

    Every read is delegated to the wrapped reader and the resulting bytes are
    decoded with `encoding` before being returned, so the methods mirror those
    of `BytesReader` but deal in `str` instead of `bytes`.
    """

    bytes_reader: BytesReader
    encoding: str

    def __init__(self, bytes_reader: BytesReader, encoding: str = "utf-8") -> None:
        self.bytes_reader = bytes_reader
        self.encoding = encoding

    def _decode(self, raw: bytes) -> str:
        # Single place where bytes read from the wrapped reader become text.
        return raw.decode(self.encoding)

    async def read_until(self, separator: str = "\n") -> str:
        """
        Read until `separator` (encoded with `encoding`) is found by the
        wrapped reader, and return the decoded result.
        """
        raw = await self.bytes_reader.read_until(separator.encode(self.encoding))
        return self._decode(raw)

    async def read_exactly(self, count: int) -> str:
        """
        Read `count` bytes and decode them. `count` is passed unchanged to the
        wrapped bytes reader, so it is a number of bytes, not of characters.
        """
        raw = await self.bytes_reader.read_exactly(count)
        return self._decode(raw)

    async def readline(self) -> str:
        """
        Read one line through the wrapped reader and return it decoded.
        """
        raw = await self.bytes_reader.readline()
        return self._decode(raw)


class TextWriter:
    """
    String-based view over a `BytesWriter`.

    Each string passed to `write` is encoded with `encoding` and forwarded to
    the wrapped writer right away.
    """

    bytes_writer: BytesWriter
    encoding: str

    def __init__(self, bytes_writer: BytesWriter, encoding: str = "utf-8") -> None:
        self.bytes_writer = bytes_writer
        self.encoding = encoding

    async def write(self, data: str) -> None:
        """
        Encode `data` with `encoding` and write it to the wrapped writer.
        """
        await self.bytes_writer.write(data.encode(self.encoding))


class MemoryBytesReader(BytesReader):
    """
    An implementation of `BytesReader` based on a given in-memory byte string.

    The reader keeps a cursor into the data. Every successful read advances
    the cursor past the returned bytes; a read that runs out of data moves the
    cursor to the end before raising.
    """

    _data: bytes
    _cursor: int

    def __init__(self, data: bytes) -> None:
        self._data = data
        self._cursor = 0

    async def read_until(self, separator: bytes = b"\n") -> bytes:
        """
        Return the bytes from the cursor up to and including the first
        occurrence of `separator`.

        Raise `asyncio.IncompleteReadError` (carrying the remaining bytes as
        `partial`) if `separator` does not occur in the unread data.
        Raise `ValueError` if `separator` is empty, in line with
        `asyncio.StreamReader.readuntil`.
        """
        if not separator:
            # An empty separator has no meaningful match position.
            raise ValueError("Separator should be at least one-byte string")

        start = self._cursor
        # Only matches lying entirely in the unread region are considered.
        found_at = self._data.find(separator, start)
        if found_at < 0:
            self._cursor = len(self._data)
            raise asyncio.IncompleteReadError(bytes(self._data[start:]), None)

        self._cursor = found_at + len(separator)
        return bytes(self._data[start : self._cursor])

    async def read_exactly(self, count: int) -> bytes:
        """
        Return the next `count` bytes.

        Raise `asyncio.IncompleteReadError` (carrying the remaining bytes as
        `partial`) if fewer than `count` bytes are left.
        Raise `ValueError` if `count` is negative, in line with
        `asyncio.StreamReader.readexactly`.
        """
        if count < 0:
            # A negative count would otherwise move the cursor backwards.
            raise ValueError("read_exactly size can not be less than zero")

        start = self._cursor
        stop = start + count
        data_size = len(self._data)
        if stop > data_size:
            self._cursor = data_size
            raise asyncio.IncompleteReadError(self._data[start:], count)

        self._cursor = stop
        return self._data[start:stop]

    def reset(self) -> None:
        """
        Move the cursor back to the beginning of the data.
        """
        self._cursor = 0


def create_memory_text_reader(data: str, encoding: str = "utf-8") -> TextReader:
    """
    Build a `TextReader` that reads `data` from memory.

    `data` is encoded with `encoding` to back a `MemoryBytesReader`, and the
    same `encoding` is used by the returned reader to decode it again.
    """
    encoded = data.encode(encoding)
    bytes_reader = MemoryBytesReader(encoded)
    return TextReader(bytes_reader, encoding)


class MemoryBytesWriter(BytesWriter):
    """
    An implementation of `BytesWriter` that records every written byte string
    in an in-memory list instead of sending it anywhere.
    """

    _items: List[bytes]

    def __init__(self) -> None:
        self._items = []

    async def write(self, data: bytes) -> None:
        """
        Record `data` as the newest written item.
        """
        self._items += [data]

    async def close(self) -> None:
        """
        Do nothing: there is no underlying channel to close.
        """
        return None

    def items(self) -> List[bytes]:
        """
        Return the list of recorded items, in the order they were written.
        The internal list itself is returned, not a copy.
        """
        return self._items


def create_memory_text_writer(encoding: str = "utf-8") -> TextWriter:
    """
    Build a `TextWriter` backed by a fresh `MemoryBytesWriter`.

    Strings written through the returned writer are encoded with `encoding`
    before being recorded by the underlying in-memory writer.
    """
    # Forward `encoding`; leaving it out would silently fall back to the
    # `TextWriter` default regardless of what the caller asked for.
    return TextWriter(MemoryBytesWriter(), encoding)


class StreamBytesReader(BytesReader):
    """
    An implementation of `BytesReader` based on `asyncio.StreamReader`.
    """

    stream_reader: asyncio.StreamReader

    def __init__(self, stream_reader: asyncio.StreamReader) -> None:
        self.stream_reader = stream_reader

    async def read_until(self, separator: bytes = b"\n") -> bytes:
        """
        Read from the stream until `separator` is found and return everything
        read, regardless of the buffer limit of the underlying stream reader.

        Raise `asyncio.IncompleteReadError` if EOF is reached first. The
        `partial` attribute of the error holds all the data read by this call,
        including the pieces that were read before EOF was detected.
        """
        # StreamReader.readuntil() may raise when its internal buffer cannot hold
        # all the input data. We need to explicitly handle the raised exceptions
        # by "parking" all partial-read results in memory.
        chunks: List[bytes] = []
        try:
            while True:
                try:
                    chunks.append(await self.stream_reader.readuntil(separator))
                    break
                except asyncio.LimitOverrunError as error:
                    # Drain what is sitting in the buffer so that the next
                    # readuntil() call can make progress.
                    chunks.append(
                        await self.stream_reader.readexactly(error.consumed)
                    )
        except asyncio.IncompleteReadError as error:
            if not chunks:
                raise
            # The error raised by the stream reader only knows about the last
            # piece. Prepend the parked chunks so that callers relying on
            # `partial` (e.g. `readline`) do not lose data.
            raise asyncio.IncompleteReadError(
                b"".join(chunks) + error.partial, error.expected
            ) from error
        return b"".join(chunks)

    async def read_exactly(self, count: int) -> bytes:
        """
        Read exactly `count` bytes by delegating to
        `asyncio.StreamReader.readexactly`.
        """
        return await self.stream_reader.readexactly(count)


class StreamBytesWriter(BytesWriter):
    """
    An implementation of `BytesWriter` based on `asyncio.StreamWriter`.
    """

    stream_writer: asyncio.StreamWriter

    def __init__(self, stream_writer: asyncio.StreamWriter) -> None:
        self.stream_writer = stream_writer

    async def write(self, data: bytes) -> None:
        """
        Hand `data` to the stream writer, then wait on `drain()` so that the
        write is flushed before returning.
        """
        self.stream_writer.write(data)
        await self.stream_writer.drain()

    async def close(self) -> None:
        """
        Close the stream writer and wait until it is fully closed.
        """
        self.stream_writer.close()
        await self._stream_writer_wait_closed()

    async def _stream_writer_wait_closed(self) -> None:
        """
        StreamWriter does not have a `wait_closed` method prior to python
        3.7. For 3.6 compatibility we have to hack it with an async busy
        loop that waits
        - first for the transport to be aware that it is closing
        - then for the socket to become unmapped

        This approach is inspired by the solution in qemu.aqmp.util.
        """
        if sys.version_info >= (3, 7):
            return await self.stream_writer.wait_closed()

        while not self.stream_writer.transport.is_closing():
            await asyncio.sleep(0)

        # Deliberately left unannotated: the `sys` module has no `IO`
        # attribute, so the previous `sys.IO` annotation named nothing real.
        transport_socket = self.stream_writer.transport.get_extra_info("socket")
        if transport_socket is not None:
            # A closed socket reports -1 as its file descriptor.
            while transport_socket.fileno() != -1:
                await asyncio.sleep(0)


@asynccontextmanager
async def connect(
    socket_path: Path, buffer_size: Optional[int] = None
) -> AsyncIterator[Tuple[BytesReader, BytesWriter]]:
    """
    Open a connection to the unix socket located at `socket_path` and expose
    it as a pair of raw binary streams: one for input, one for output.
    Intended usage:

    ```
    async with connect(socket_path) as (input_stream, output_stream):
        # Read from input_stream and write into output_stream here
        ...
    ```

    `buffer_size` controls how large the input buffer of the returned reader
    is. When it is left as `None`, 64kb is used.

    Creating, connecting and closing the socket are all taken care of by this
    context manager. An `OSError` raised while connecting, or raised out of
    the body of the `async with` block, is re-raised as `ConnectionFailure`.
    """
    # 2 ** 16 bytes == 64kb, the documented default.
    limit = 2 ** 16 if buffer_size is None else buffer_size

    # Stays `None` until the connection is established, so that the cleanup
    # below knows whether there is anything to close.
    writer: Optional[BytesWriter] = None
    try:
        raw_reader, raw_writer = await asyncio.open_unix_connection(
            str(socket_path), limit=limit
        )
        reader = StreamBytesReader(raw_reader)
        writer = StreamBytesWriter(raw_writer)
        yield reader, writer
    except OSError as error:
        raise ConnectionFailure() from error
    finally:
        if writer is not None:
            await writer.close()


@asynccontextmanager
async def connect_in_text_mode(
    socket_path: Path, buffer_size: Optional[int] = None
) -> AsyncIterator[Tuple[TextReader, TextWriter]]:
    """
    A higher-level, line-oriented counterpart of `connect`, for callers that
    would rather not deal with binary I/O directly.

    It behaves like `connect`, but the yielded streams work in text mode:
    their read/write APIs take and return UTF-8 encoded `str` rather than
    `bytes`.
    """
    async with connect(socket_path, buffer_size) as (bytes_reader, bytes_writer):
        text_reader = TextReader(bytes_reader, "utf-8")
        text_writer = TextWriter(bytes_writer, "utf-8")
        yield text_reader, text_writer


async def create_async_stdin_stdout() -> Tuple[TextReader, TextWriter]:
    """
    Wrap `sys.stdin` and `sys.stdout` into async text streams.

    Out of the box, `sys.stdin` and `sys.stdout` are synchronous: a read or a
    write blocks until it succeeds. That is unlike the async socket channels
    obtained from `connect` or `connect_in_text_mode`.

    The reader/writer pair returned here behaves like those async socket
    channels, which makes it possible to write code that is agnostic of the
    low-level I/O: the high-level logic does not need to care whether the
    channel is backed by a socket or by stdin/stdout.
    """
    loop = asyncio.get_event_loop()

    # Input side: attach stdin to a StreamReader through a read pipe.
    stdin_stream = asyncio.StreamReader(loop=loop)

    def make_reader_protocol() -> asyncio.StreamReaderProtocol:
        return asyncio.StreamReaderProtocol(stdin_stream)

    await loop.connect_read_pipe(make_reader_protocol, sys.stdin)

    # Output side: attach stdout to a StreamWriter through a write pipe.
    write_transport, write_protocol = await loop.connect_write_pipe(
        asyncio.streams.FlowControlMixin, sys.stdout
    )
    stdout_stream = asyncio.StreamWriter(
        write_transport, write_protocol, stdin_stream, loop
    )

    text_reader = TextReader(StreamBytesReader(stdin_stream))
    text_writer = TextWriter(StreamBytesWriter(stdout_stream))
    return (text_reader, text_writer)


class BackgroundTask(abc.ABC):
    """
    Abstract interface of a unit of work that can be awaited through `run`.
    """

    @abc.abstractmethod
    async def run(self) -> None:
        """
        Execute the task. Subclasses must override this method.
        """
        raise NotImplementedError()


class BackgroundTaskManager:
    """
    Controls the lifetime of one background task.

    The only internal state is whether a run of the task is in flight, kept in
    an attribute of type `Optional[Future]`. A value other than `None` means
    that the task is currently running in the background.
    """

    _task: BackgroundTask
    _ongoing: "Optional[asyncio.Future[None]]"

    def __init__(self, task: BackgroundTask) -> None:
        """
        Create a manager for `task`. The `run()` method of `task` is awaited
        once the task gets started through `ensure_task_running()`.

        The given task must not swallow asyncio `CancelledError` internally,
        or shutting the task down may not work properly.
        """
        self._ongoing = None
        self._task = task

    async def _run_task(self) -> None:
        # Wrapper actually scheduled on the event loop: run the task, log how
        # it ended, and always mark the manager as idle afterwards.
        try:
            await self._task.run()
        except asyncio.CancelledError:
            LOG.info("Terminate background task on explicit cancelling request.")
        except Exception as error:
            LOG.error(f"Background task unexpectedly quited: {error}")
        finally:
            self._ongoing = None

    def is_task_running(self) -> bool:
        """
        Tell whether a run of the background task is currently in flight.
        """
        if self._ongoing is None:
            return False
        return True

    async def ensure_task_running(self) -> None:
        """
        Schedule the background task on the event loop unless it is running
        already. Scheduling alone does not execute anything: the task only
        gets to run once the current task gives control back to the event
        loop (e.g. by `await`ing on something).
        """
        if self._ongoing is not None:
            return
        self._ongoing = asyncio.ensure_future(self._run_task())

    async def ensure_task_stop(self) -> None:
        """
        Stop the background task if it is actively running.
        """
        pending = self._ongoing
        if pending is None:
            return
        try:
            pending.cancel()
            await pending
        except asyncio.CancelledError:
            # Reached when `cancel` lands before `_run_task` had a chance to
            # start executing.
            LOG.info("Terminate background task on explicit cancelling request.")
        finally:
            self._ongoing = None
