"""Pubsub Emulator for testing."""

from google.cloud.pubsub_v1.proto import pubsub_pb2_grpc, pubsub_pb2
from google.protobuf import empty_pb2, json_format
from typing import Dict, Iterator, List, Optional, Set
import concurrent.futures
import grpc
import json
import logging
import os
import time
import uuid


class LazyFormat:
    """Defer rendering a protobuf message until a log record is formatted.

    Pass an instance as a %-style logging argument. The message is converted
    through ``json_format.MessageToDict`` only when ``str()`` is called on the
    wrapper, so records filtered out by level never pay for the conversion.
    """

    def __init__(self, value):
        """Hold on to *value* (a protobuf message) without converting it."""
        self.value = value

    def __str__(self):
        """Render the message as ``str(dict)`` minus the outer curly braces."""
        rendered = str(json_format.MessageToDict(self.value))
        return rendered[1:-1]


class Subscription:
    """Message bookkeeping for a single emulated subscription.

    ``published`` is the list of messages waiting to be pulled, oldest first.
    ``pulled`` maps ack IDs to messages that were delivered but not yet
    acknowledged. Instances are stored in sets and compared by identity, so
    this deliberately keeps the default ``__eq__``/``__hash__``.
    """

    def __init__(self):
        """Start with no pending and no outstanding messages."""
        self.pulled = {}
        self.published = []
        self.pulled = {}


class PubsubEmulator(
    pubsub_pb2_grpc.PublisherServicer, pubsub_pb2_grpc.SubscriberServicer
):
    """Pubsub gRPC emulator for testing.

    Implements a subset of the Publisher and Subscriber services with
    in-memory state:

    * ``topics`` maps each topic name to the set of ``Subscription`` objects
      attached to it. Publish fans messages out to every one of them.
    * ``subscriptions`` maps each subscription name to its ``Subscription``.
    * ``status_codes`` maps topic names to a status code that Publish aborts
      with. It is set through the repurposed UpdateTopic API.
    * ``sleep``, when not None, makes every Publish (on any topic) sleep that
      many seconds and skip recording messages. It is also set via UpdateTopic.

    Ack IDs handed out by Pull are the messages' ``message_id`` values. Ack
    deadlines are not tracked: only a deadline of 0 returns a message to the
    pending queue.

    NOTE(review): no locking guards the shared state. With ``max_workers`` > 1,
    handlers may run concurrently on the server's thread pool.
    """

    def __init__(
        self,
        host: str = os.environ.get("HOST", "0.0.0.0"),
        max_workers: int = int(os.environ.get("MAX_WORKERS", 1)),
        port: int = int(os.environ.get("PORT", 0)),
        topics: Optional[str] = os.environ.get("TOPICS"),
    ):
        """Initialize a new PubsubEmulator and add it to a gRPC server.

        The defaults come from the HOST, MAX_WORKERS, PORT and TOPICS
        environment variables. Default values are evaluated once, when the
        class is defined, not on each call.

        :param host: Interface the server binds to.
        :param max_workers: Size of the server's thread pool.
        :param port: Port to bind. The port actually bound is stored back into
            ``self.port``.
        :param topics: Optional JSON-encoded list of topic names to pre-create.
        """
        self.logger = logging.getLogger("pubsub_emulator")
        # topic name -> subscriptions attached to that topic
        self.topics: Dict[str, Set[Subscription]] = {
            topic: set() for topic in (json.loads(topics) if topics else [])
        }
        # subscription name -> its pending/outstanding message queues
        self.subscriptions: Dict[str, Subscription] = {}
        # topic name -> status code Publish should abort with
        self.status_codes: Dict[str, grpc.StatusCode] = {}
        # seconds every Publish sleeps before returning; None disables it
        self.sleep: Optional[float] = None
        self.host = host
        self.port = port
        self.max_workers = max_workers
        self.create_server()

    def create_server(self):
        """Create and start a new grpc.Server configured with PubsubEmulator.

        Message size limits are disabled (-1) in both directions. The port
        returned by ``add_insecure_port`` replaces ``self.port``, so an
        OS-chosen port is visible to callers.
        """
        self.server = grpc.server(
            concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers),
            options=[
                ("grpc.max_receive_message_length", -1),
                ("grpc.max_send_message_length", -1),
            ],
        )
        self.port = self.server.add_insecure_port("%s:%d" % (self.host, self.port))
        pubsub_pb2_grpc.add_PublisherServicer_to_server(self, self.server)
        pubsub_pb2_grpc.add_SubscriberServicer_to_server(self, self.server)
        self.server.start()
        self.logger.info(
            "Listening on %s:%d",
            self.host,
            self.port,
            extra={"host": self.host, "port": self.port},
        )

    def CreateTopic(
        self, request: pubsub_pb2.Topic, context: grpc.ServicerContext
    ):  # noqa: D403
        """CreateTopic implementation.

        Aborts with ALREADY_EXISTS for a known topic. Otherwise it registers
        the topic with no subscriptions and echoes the request back.
        """
        self.logger.debug("CreateTopic(%s)", LazyFormat(request))
        if request.name in self.topics:
            # context.abort raises, so execution stops here
            context.abort(grpc.StatusCode.ALREADY_EXISTS, "Topic already exists")
        self.topics[request.name] = set()
        return request

    def DeleteTopic(
        self, request: pubsub_pb2.DeleteTopicRequest, context: grpc.ServicerContext
    ):  # noqa: D403
        """DeleteTopic implementation.

        Aborts with NOT_FOUND for an unknown topic. Subscriptions that were
        attached stay in ``self.subscriptions``; they just stop receiving
        messages.
        """
        self.logger.debug("DeleteTopic(%s)", LazyFormat(request))
        try:
            self.topics.pop(request.topic)
        except KeyError:
            context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
        return empty_pb2.Empty()

    def CreateSubscription(
        self, request: pubsub_pb2.Subscription, context: grpc.ServicerContext
    ):  # noqa: D403
        """CreateSubscription implementation.

        Aborts with ALREADY_EXISTS for a known subscription, or NOT_FOUND for
        an unknown topic. Otherwise it attaches a fresh ``Subscription`` to the
        topic and echoes the request back.
        """
        self.logger.debug("CreateSubscription(%s)", LazyFormat(request))
        if request.name in self.subscriptions:
            context.abort(grpc.StatusCode.ALREADY_EXISTS, "Subscription already exists")
        elif request.topic not in self.topics:
            context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
        subscription = Subscription()
        self.subscriptions[request.name] = subscription
        self.topics[request.topic].add(subscription)
        return request

    def DeleteSubscription(
        self,
        request: pubsub_pb2.DeleteSubscriptionRequest,
        context: grpc.ServicerContext,
    ):  # noqa: D403
        """DeleteSubscription implementation.

        Aborts with NOT_FOUND for an unknown subscription. Otherwise it
        forgets the subscription and detaches it from every topic.
        """
        self.logger.debug("DeleteSubscription(%s)", LazyFormat(request))
        try:
            subscription = self.subscriptions.pop(request.subscription)
        except KeyError:
            context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
        for subscriptions in self.topics.values():
            subscriptions.discard(subscription)
        return empty_pb2.Empty()

    def Publish(
        self, request: pubsub_pb2.PublishRequest, context: grpc.ServicerContext
    ):
        """Publish implementation.

        Checks, in order:

        1. A status code override for the topic; aborts with that code.
        2. Topic existence; aborts with NOT_FOUND.
        3. The global ``sleep`` override; sleeps, then returns fresh message
           IDs without storing anything.

        Otherwise each message gets a new hex UUID ``message_id`` and is
        appended to every subscription attached to the topic.
        """
        self.logger.debug("Publish(%.100s)", LazyFormat(request))
        if request.topic in self.status_codes:
            context.abort(self.status_codes[request.topic], "Override")
        try:
            subscriptions = self.topics[request.topic]
        except KeyError:
            context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
        message_ids: List[str] = [uuid.uuid4().hex for _ in request.messages]
        if self.sleep is not None:
            time.sleep(self.sleep)
            # return a valid response without recording messages
            return pubsub_pb2.PublishResponse(message_ids=message_ids)
        for _id, message in zip(message_ids, request.messages):
            message.message_id = _id
        for subscription in subscriptions:
            # every subscription receives references to the same message objects
            subscription.published.extend(request.messages)
        return pubsub_pb2.PublishResponse(message_ids=message_ids)

    def Pull(self, request: pubsub_pb2.PullRequest, context: grpc.ServicerContext):
        """Pull implementation.

        Takes up to ``max_messages`` (100 when unset or 0) of the oldest
        pending messages and moves them into ``pulled``, keyed by message_id.
        They are returned with ``ack_id`` equal to their ``message_id``.
        Aborts with NOT_FOUND for an unknown subscription.
        """
        self.logger.debug("Pull(%.100s)", LazyFormat(request))
        received_messages: List[pubsub_pb2.ReceivedMessage] = []
        try:
            subscription = self.subscriptions[request.subscription]
        except KeyError:
            context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
        messages = subscription.published[: request.max_messages or 100]
        subscription.pulled.update(
            {message.message_id: message for message in messages}
        )
        for message in messages:
            # removed by value; a message that is already gone is tolerated
            try:
                subscription.published.remove(message)
            except ValueError:
                pass
        received_messages = [
            pubsub_pb2.ReceivedMessage(ack_id=message.message_id, message=message)
            for message in messages
        ]
        return pubsub_pb2.PullResponse(received_messages=received_messages)

    def Acknowledge(
        self, request: pubsub_pb2.AcknowledgeRequest, context: grpc.ServicerContext
    ):
        """Acknowledge implementation.

        Drops each acked message from ``pulled``. Aborts with NOT_FOUND for an
        unknown subscription or ack ID. Ack IDs earlier in the same request
        than an unknown one have already been removed when that happens.
        """
        self.logger.debug("Acknowledge(%s)", LazyFormat(request))
        try:
            subscription = self.subscriptions[request.subscription]
        except KeyError:
            context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
        for ack_id in request.ack_ids:
            try:
                subscription.pulled.pop(ack_id)
            except KeyError:
                context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
        return empty_pb2.Empty()

    def ModifyAckDeadline(
        self,
        request: pubsub_pb2.ModifyAckDeadlineRequest,
        context: grpc.ServicerContext,
    ) -> empty_pb2.Empty:  # noqa: D403
        """ModifyAckDeadline implementation.

        A deadline of 0 moves each listed message from ``pulled`` to the end of
        ``published`` so it is redelivered. Other deadlines are accepted and
        ignored. Aborts with NOT_FOUND for an unknown subscription, or for an
        unknown ack ID when the deadline is 0.
        """
        self.logger.debug("ModifyAckDeadline(%s)", LazyFormat(request))
        try:
            subscription = self.subscriptions[request.subscription]
        except KeyError:
            context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
        # deadline is not tracked so only handle expiration when set to 0
        if request.ack_deadline_seconds == 0:
            for ack_id in request.ack_ids:
                try:
                    # move message from pulled back to published
                    subscription.published.append(subscription.pulled.pop(ack_id))
                except KeyError:
                    context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
        return empty_pb2.Empty()

    def StreamingPull(
        self,
        request_iterator: Iterator[pubsub_pb2.StreamingPullRequest],
        context: grpc.ServicerContext,
    ):  # noqa: D403
        """StreamingPull implementation.

        For every incoming request, this applies the request's acks and
        per-message deadline modifications, then yields one response with up
        to 100 pending messages.

        Only the first request on a stream is required to name the
        subscription; later requests leave it empty. The most recent non-empty
        subscription name is therefore reused for the whole stream.
        """
        subscription_name = ""
        for request in request_iterator:
            self.logger.debug("StreamingPull(%.100s)", LazyFormat(request))
            if request.subscription:
                subscription_name = request.subscription
            if request.ack_ids:
                self.Acknowledge(
                    pubsub_pb2.AcknowledgeRequest(
                        subscription=subscription_name, ack_ids=request.ack_ids
                    ),
                    context,
                )
            # modify_deadline_ack_ids and modify_deadline_seconds are parallel
            for ack_id, seconds in zip(
                request.modify_deadline_ack_ids, request.modify_deadline_seconds
            ):
                self.ModifyAckDeadline(
                    pubsub_pb2.ModifyAckDeadlineRequest(
                        subscription=subscription_name,
                        ack_ids=[ack_id],
                        ack_deadline_seconds=seconds,
                    ),
                    context,
                )
            yield pubsub_pb2.StreamingPullResponse(
                received_messages=self.Pull(
                    pubsub_pb2.PullRequest(
                        subscription=subscription_name, max_messages=100
                    ),
                    context,
                ).received_messages
            )

    def UpdateTopic(
        self, request: pubsub_pb2.UpdateTopicRequest, context: grpc.ServicerContext
    ):
        """Repurpose UpdateTopic API for setting up test conditions.

        :param request.topic.name: Name of the topic that needs overrides.
        :param request.update_mask.paths: A list of overrides, of the form
        "key=value".

        Valid override keys are "status_code" and "sleep". An override value of
        "" disables the override. A path without "=" is rejected with
        INVALID_ARGUMENT; an unknown key is rejected with NOT_FOUND.

        For the override key "status_code", the override value indicates the
        status code that Publish requests to this topic should abort with.
        Non-empty override values must name a member of `grpc.StatusCode`
        such as "UNIMPLEMENTED" (case-insensitive).

        For the override key "sleep", the override value indicates a number of
        seconds Publish requests should sleep before returning. Non-empty
        override values must be a finite, non-negative float. This override is
        global: it applies to Publish on every topic, not only
        ``request.topic.name``. While it is set, Publish requests return a
        valid response without recording messages.
        """
        self.logger.debug("UpdateTopic(%s)", LazyFormat(request))
        topic_name = request.topic.name
        for override in request.update_mask.paths:
            key, sep, value = override.partition("=")
            if not sep:
                context.abort(grpc.StatusCode.INVALID_ARGUMENT, "Invalid override")
            key = key.lower()
            if key in ("status_code", "statuscode"):
                if value:
                    try:
                        self.status_codes[topic_name] = grpc.StatusCode[value.upper()]
                    except KeyError:
                        context.abort(
                            grpc.StatusCode.INVALID_ARGUMENT, "Invalid status code"
                        )
                else:
                    try:
                        del self.status_codes[topic_name]
                    except KeyError:
                        context.abort(
                            grpc.StatusCode.NOT_FOUND, "Status code override not found"
                        )
            elif key == "sleep":
                if value:
                    try:
                        sleep = float(value)
                        # time.sleep raises for negative, NaN and infinite values,
                        # which would break every later Publish; reject them here
                        if not 0 <= sleep < float("inf"):
                            raise ValueError(value)
                    except ValueError:
                        context.abort(
                            grpc.StatusCode.INVALID_ARGUMENT, "Invalid sleep time"
                        )
                    self.sleep = sleep
                else:
                    self.sleep = None
            else:
                context.abort(grpc.StatusCode.NOT_FOUND, "Path not found")
        return request.topic


def main():
    """Serve the PubsubEmulator over gRPC until interrupted.

    The ``pubsub_emulator`` logger gets a stderr stream handler. Its level is
    the ``logging`` attribute named by ``LOG_LEVEL`` (default ``DEBUG``). On
    KeyboardInterrupt the server is stopped with no grace period.
    """
    level_name = os.environ.get("LOG_LEVEL", "DEBUG").upper()
    logger = logging.getLogger("pubsub_emulator")
    logger.addHandler(logging.StreamHandler())
    logger.setLevel(getattr(logging, level_name))

    emulator = PubsubEmulator()
    server = emulator.server
    try:
        # RPCs are handled on the server's thread pool; the main thread only
        # needs to stay alive.
        while True:
            time.sleep(60)
    except KeyboardInterrupt:
        server.stop(grace=None)


if __name__ == "__main__":
    main()
