ingestion-edge/pubsub_emulator.py (254 lines of code) (raw):
"""Pubsub Emulator for testing."""
from google.cloud.pubsub_v1.proto import pubsub_pb2_grpc, pubsub_pb2
from google.protobuf import empty_pb2, json_format
from typing import Dict, Iterator, List, Optional, Set
import concurrent.futures
import grpc
import json
import logging
import os
import time
import uuid
class LazyFormat:
"""Container class for lazily formatting logged protobuf."""
def __init__(self, value):
"""Initialize new container."""
self.value = value
def __str__(self):
"""Get str(dict(value)) without surrounding curly braces."""
return str(json_format.MessageToDict(self.value))[1:-1]
class Subscription:
"""Container class for subscription messages."""
def __init__(self):
"""Initialize subscription messages queue."""
self.published = []
self.pulled = {}
class PubsubEmulator(
pubsub_pb2_grpc.PublisherServicer, pubsub_pb2_grpc.SubscriberServicer
):
"""Pubsub gRPC emulator for testing."""
def __init__(
self,
host: str = os.environ.get("HOST", "0.0.0.0"),
max_workers: int = int(os.environ.get("MAX_WORKERS", 1)),
port: int = int(os.environ.get("PORT", 0)),
topics: Optional[str] = os.environ.get("TOPICS"),
):
"""Initialize a new PubsubEmulator and add it to a gRPC server."""
self.logger = logging.getLogger("pubsub_emulator")
self.topics: Dict[str, Set[Subscription]] = {
topic: set() for topic in (json.loads(topics) if topics else [])
}
self.subscriptions: Dict[str, Subscription] = {}
self.status_codes: Dict[str, grpc.StatusCode] = {}
self.sleep: Optional[float] = None
self.host = host
self.port = port
self.max_workers = max_workers
self.create_server()
def create_server(self):
"""Create and start a new grpc.Server configured with PubsubEmulator."""
self.server = grpc.server(
concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers),
options=[
("grpc.max_receive_message_length", -1),
("grpc.max_send_message_length", -1),
],
)
self.port = self.server.add_insecure_port("%s:%d" % (self.host, self.port))
pubsub_pb2_grpc.add_PublisherServicer_to_server(self, self.server)
pubsub_pb2_grpc.add_SubscriberServicer_to_server(self, self.server)
self.server.start()
self.logger.info(
"Listening on %s:%d",
self.host,
self.port,
extra={"host": self.host, "port": self.port},
)
def CreateTopic(
self, request: pubsub_pb2.Topic, context: grpc.ServicerContext
): # noqa: D403
"""CreateTopic implementation."""
self.logger.debug("CreateTopic(%s)", LazyFormat(request))
if request.name in self.topics:
context.abort(grpc.StatusCode.ALREADY_EXISTS, "Topic already exists")
self.topics[request.name] = set()
return request
def DeleteTopic(
self, request: pubsub_pb2.DeleteTopicRequest, context: grpc.ServicerContext
): # noqa: D403
"""DeleteTopic implementation."""
self.logger.debug("DeleteTopic(%s)", LazyFormat(request))
try:
self.topics.pop(request.topic)
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
return empty_pb2.Empty()
def CreateSubscription(
self, request: pubsub_pb2.Subscription, context: grpc.ServicerContext
): # noqa: D403
"""CreateSubscription implementation."""
self.logger.debug("CreateSubscription(%s)", LazyFormat(request))
if request.name in self.subscriptions:
context.abort(grpc.StatusCode.ALREADY_EXISTS, "Subscription already exists")
elif request.topic not in self.topics:
context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
subscription = Subscription()
self.subscriptions[request.name] = subscription
self.topics[request.topic].add(subscription)
return request
def DeleteSubscription(
self,
request: pubsub_pb2.DeleteSubscriptionRequest,
context: grpc.ServicerContext,
): # noqa: D403
"""DeleteSubscription implementation."""
self.logger.debug("DeleteSubscription(%s)", LazyFormat(request))
try:
subscription = self.subscriptions.pop(request.subscription)
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
for subscriptions in self.topics.values():
subscriptions.discard(subscription)
return empty_pb2.Empty()
def Publish(
self, request: pubsub_pb2.PublishRequest, context: grpc.ServicerContext
):
"""Publish implementation."""
self.logger.debug("Publish(%.100s)", LazyFormat(request))
if request.topic in self.status_codes:
context.abort(self.status_codes[request.topic], "Override")
message_ids: List[str] = []
try:
subscriptions = self.topics[request.topic]
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Topic not found")
message_ids = [uuid.uuid4().hex for _ in request.messages]
if self.sleep is not None:
time.sleep(self.sleep)
# return a valid response without recording messages
return pubsub_pb2.PublishResponse(message_ids=message_ids)
for _id, message in zip(message_ids, request.messages):
message.message_id = _id
for subscription in subscriptions:
subscription.published.extend(request.messages)
return pubsub_pb2.PublishResponse(message_ids=message_ids)
def Pull(self, request: pubsub_pb2.PullRequest, context: grpc.ServicerContext):
"""Pull implementation."""
self.logger.debug("Pull(%.100s)", LazyFormat(request))
received_messages: List[pubsub_pb2.ReceivedMessage] = []
try:
subscription = self.subscriptions[request.subscription]
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
messages = subscription.published[: request.max_messages or 100]
subscription.pulled.update(
{message.message_id: message for message in messages}
)
for message in messages:
try:
subscription.published.remove(message)
except ValueError:
pass
received_messages = [
pubsub_pb2.ReceivedMessage(ack_id=message.message_id, message=message)
for message in messages
]
return pubsub_pb2.PullResponse(received_messages=received_messages)
def Acknowledge(
self, request: pubsub_pb2.AcknowledgeRequest, context: grpc.ServicerContext
):
"""Acknowledge implementation."""
self.logger.debug("Acknowledge(%s)", LazyFormat(request))
try:
subscription = self.subscriptions[request.subscription]
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
for ack_id in request.ack_ids:
try:
subscription.pulled.pop(ack_id)
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
return empty_pb2.Empty()
def ModifyAckDeadline(
self,
request: pubsub_pb2.ModifyAckDeadlineRequest,
context: grpc.ServicerContext,
) -> empty_pb2.Empty: # noqa: D403
"""ModifyAckDeadline implementation."""
self.logger.debug("ModifyAckDeadline(%s)", LazyFormat(request))
try:
subscription = self.subscriptions[request.subscription]
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Subscription not found")
# deadline is not tracked so only handle expiration when set to 0
if request.ack_deadline_seconds == 0:
for ack_id in request.ack_ids:
try:
# move message from pulled back to published
subscription.published.append(subscription.pulled.pop(ack_id))
except KeyError:
context.abort(grpc.StatusCode.NOT_FOUND, "Ack ID not found")
return empty_pb2.Empty()
def StreamingPull(
self,
request_iterator: Iterator[pubsub_pb2.StreamingPullRequest],
context: grpc.ServicerContext,
): # noqa: D403
"""StreamingPull implementation."""
for request in request_iterator:
self.logger.debug("StreamingPull(%.100s)", LazyFormat(request))
if request.ack_ids:
self.Acknowledge(
pubsub_pb2.AcknowledgeRequest(
subscription=request.subscription, ack_ids=request.ack_ids
),
context,
)
if request.modify_deadline_seconds:
for ack_id, seconds in zip(
request.modify_deadline_ack_ids, request.modify_deadline_seconds
):
self.ModifyAckDeadline(
pubsub_pb2.ModifyAckDeadlineRequest(
subscription=request.subscription,
ack_ids=[ack_id],
seconds=seconds,
),
context,
)
yield pubsub_pb2.StreamingPullResponse(
received_messages=self.Pull(
pubsub_pb2.PullRequest(
subscription=request.subscription, max_messages=100
),
context,
).received_messages
)
def UpdateTopic(
self, request: pubsub_pb2.UpdateTopicRequest, context: grpc.ServicerContext
):
"""Repurpose UpdateTopic API for setting up test conditions.
:param request.topic.name: Name of the topic that needs overrides.
:param request.update_mask.paths: A list of overrides, of the form
"key=value".
Valid override keys are "status_code" and "sleep". An override value of
"" disables the override.
For the override key "status_code" the override value indicates the
status code that should be returned with an empty response by Publish
requests, and non-empty override values must be a property of
`grpc.StatusCode` such as "UNIMPLEMENTED".
For the override key "sleep" the override value indicates a number of
seconds Publish requests should sleep before returning, and non-empty
override values must be a valid float. Publish requests will return
a valid response without recording messages.
"""
self.logger.debug("UpdateTopic(%s)", LazyFormat(request))
for override in request.update_mask.paths:
key, value = override.split("=", 1)
if key.lower() in ("status_code", "statuscode"):
if value:
try:
self.status_codes[request.topic.name] = getattr(
grpc.StatusCode, value.upper()
)
except AttributeError:
context.abort(
grpc.StatusCode.INVALID_ARGUMENT, "Invalid status code"
)
else:
try:
del self.status_codes[request.topic.name]
except KeyError:
context.abort(
grpc.StatusCode.NOT_FOUND, "Status code override not found"
)
elif key.lower() == "sleep":
if value:
try:
self.sleep = float(value)
except ValueError:
context.abort(
grpc.StatusCode.INVALID_ARGUMENT, "Invalid sleep time"
)
else:
self.sleep = None
else:
context.abort(grpc.StatusCode.Not_FOUND, "Path not found")
return request.topic
def main():
"""Run PubsubEmulator gRPC server."""
# configure logging
logger = logging.getLogger("pubsub_emulator")
logger.addHandler(logging.StreamHandler())
logger.setLevel(getattr(logging, os.environ.get("LOG_LEVEL", "DEBUG").upper()))
# start server
server = PubsubEmulator().server
try:
while True:
time.sleep(60)
except KeyboardInterrupt:
server.stop(grace=None)
if __name__ == "__main__":
main()