services/ui_backend_service/api/ws.py (176 lines of code) (raw):

import os import json import time import asyncio import collections from aiohttp import web, WSMsgType from typing import List, Dict, Any, Callable from .utils import resource_conditions, TTLQueue, postprocess_chain from services.utils import logging from pyee import AsyncIOEventEmitter from ..data.refiner import TaskRefiner, ArtifactRefiner from throttler import throttle_simultaneous from services.data.db_utils import DBResponse from services.data.tagging_utils import apply_run_tags_to_db_response WS_QUEUE_TTL_SECONDS = os.environ.get("WS_QUEUE_TTL_SECONDS", 60 * 5) # 5 minute TTL by default WS_POSTPROCESS_CONCURRENCY_LIMIT = int(os.environ.get("WS_POSTPROCESS_CONCURRENCY_LIMIT", 8)) SUBSCRIBE = 'SUBSCRIBE' UNSUBSCRIBE = 'UNSUBSCRIBE' WSSubscription = collections.namedtuple( "WSSubscription", "ws disconnected_ts fullpath resource query uuid filter") class Websocket(object): ''' Adds a '/ws' endpoint and support for broadcasting realtime resource events to subscribed frontend clients. Subscribe to runs created by user dipper: /runs?_tags=user:dipper 'uuid' can be used to identify specific subscription. Subscribe to future events: {"type": "SUBSCRIBE", "uuid": "myst3rySh4ck", "resource": "/runs"} Subscribing to future events and return past data since unix time (seconds): {"type": "SUBSCRIBE", "uuid": "myst3rySh4ck", "resource": "/runs", "since": 1602752197} Unsubscribe: {"type": "UNSUBSCRIBE", "uuid": "myst3rySh4ck"} Example event: {"type": "UPDATE", "uuid": "myst3rySh4ck", "resource": "/runs", "data": {"foo": "bar"}} ''' subscriptions: List[WSSubscription] = [] def __init__(self, app, db, event_emitter=None, queue_ttl: int = WS_QUEUE_TTL_SECONDS, cache=None): self.event_emitter = event_emitter or AsyncIOEventEmitter() self.db = db self.queue = TTLQueue(queue_ttl) self.task_refiner = TaskRefiner(cache=cache.artifact_cache) if cache else None self.artifact_refiner = ArtifactRefiner(cache=cache.artifact_cache) if cache else None self.logger = logging.getLogger("Websocket") event_emitter.on('notify', self.event_handler) app.router.add_route('GET', '/ws', self.websocket_handler) self.loop = asyncio.get_event_loop() async def event_handler(self, operation: str, resources: List[str], data: Dict, table_name: str = None, filter_dict: Dict = {}): """ Event handler for websocket events on 'notify'. Either receives raw data from table triggers listener and either performs a database load before broadcasting from the provided table, or receives predefined data and broadcasts it as-is. Parameters ---------- operation : str name of the operation related to the DB event, either 'INSERT' or 'UPDATE' resources : List[str] List of resource paths that this event is related to. Used strictly for broadcasting to websocket subscriptions data : Dict The data of the record to be broadcast. Can either be complete, or partial. In case of partial data (and a provided table name) this is only used for the DB query. table_name : str (optional) name of the table that the complete data should be queried from. filter_dict : Dict (optional) a dictionary of filters used in the query when fetching complete data. """ # Check if event needs to be broadcast (if anyone is subscribed to the resource) if any(subscription.resource in resources for subscription in self.subscriptions): # load the data and postprocessor for broadcasting if table # is provided (otherwise data has already been loaded in advance) if table_name: table = self.db.get_table_by_name(table_name) _postprocess = await self.get_table_postprocessor(table_name) _data = await load_data_from_db(table, data, filter_dict, postprocess=_postprocess) else: _data = data if not _data: # Skip sending this event to subscriptions in case data is None or empty. # This could be caused by insufficient/broken data and can break the UI. return # Append event to the queue so that we can later dispatch them in case of disconnections # # NOTE: server instance specific ws queue will not work when scaling across multiple instances. # but on the other hand loading data and pushing everything into the queue for every server instance is also # a suboptimal solution. await self.queue.append({ 'operation': operation, 'resources': resources, 'data': _data }) for subscription in self.subscriptions: try: if subscription.disconnected_ts and time.time() - subscription.disconnected_ts > WS_QUEUE_TTL_SECONDS: await self.unsubscribe_from(subscription.ws, subscription.uuid) else: await self._event_subscription(subscription, operation, resources, _data) except ConnectionResetError: self.logger.debug("Trying to broadcast to a stale subscription. Unsubscribing") await self.unsubscribe_from(subscription.ws, subscription.uuid) except Exception: self.logger.exception("Broadcasting to subscription failed") async def _event_subscription(self, subscription: WSSubscription, operation: str, resources: List[str], data: Dict): for resource in resources: if subscription.resource == resource: # Check if possible filters match this event # only if the subscription actually provided conditions. if subscription.filter: filters_match_request = subscription.filter(data) else: filters_match_request = True if filters_match_request: payload = {'type': operation, 'uuid': subscription.uuid, 'resource': resource, 'data': data} await subscription.ws.send_str(json.dumps(payload)) async def subscribe_to(self, ws, uuid: str, resource: str, since: int): # Always unsubscribe existing duplicate identifiers await self.unsubscribe_from(ws, uuid) # Create new subscription _resource, query, filter_fn = resource_conditions(resource) subscription = WSSubscription( ws=ws, fullpath=resource, resource=_resource, query=query, uuid=uuid, filter=filter_fn, disconnected_ts=None) self.subscriptions.append(subscription) # Send previous events that client might have missed due to disconnection if since: # Subtract 1 second to make sure all events are included event_queue = await self.queue.values_since(since) for _, event in event_queue: self.loop.create_task( self._event_subscription(subscription, event['operation'], event['resources'], event['data']) ) async def unsubscribe_from(self, ws, uuid: str = None): if uuid: self.subscriptions = list( filter(lambda s: uuid != s.uuid or ws != s.ws, self.subscriptions)) else: self.subscriptions = list( filter(lambda s: ws != s.ws, self.subscriptions)) async def handle_disconnect(self, ws): """ Sets disconnected timestamp on websocket subscription without removing it from the list. Removing is handled by event_handler that checks for expired subscriptions before emitting """ self.subscriptions = list( map( lambda sub: sub._replace(disconnected_ts=time.time()) if sub.ws == ws else sub, self.subscriptions) ) async def websocket_handler(self, request): "Handler for received messages from the open Web Socket connection." # TODO: Consider using options autoping=True and heartbeat=20 if supported by clients. ws = web.WebSocketResponse() await ws.prepare(request) while not ws.closed: async for msg in ws: if msg.type == WSMsgType.TEXT: try: # Custom ping message handling. # If someone is pinging, lets answer with pong rightaway. if msg.data == "__ping__": await ws.send_str("__pong__") else: payload = json.loads(msg.data) op_type = payload.get("type") resource = payload.get("resource") uuid = payload.get("uuid") since = payload.get("since") if since is not None and str(since).isnumeric(): since = int(since) else: since = None if op_type == SUBSCRIBE and uuid and resource: await self.subscribe_to(ws, uuid, resource, since) elif op_type == UNSUBSCRIBE and uuid: await self.unsubscribe_from(ws, uuid) except Exception: self.logger.exception("Exception occurred.") # Always remove clients from listeners await self.handle_disconnect(ws) return ws @throttle_simultaneous(count=8) async def get_table_postprocessor(self, table_name): refiner_postprocess = None table = None if table_name == self.db.task_table_postgres.table_name: table = self.db.run_table_postgres refiner_postprocess = self.task_refiner.postprocess elif table_name == self.db.artifact_table_postgres.table_name: table = self.db.run_table_postgres refiner_postprocess = self.artifact_refiner.postprocess if table: async def _tags_postprocess(db_response: DBResponse, invalidate_cache=False): flow_id = db_response.body.get('flow_id') run_id = db_response.body.get('run_id') or db_response.body.get('run_number') if not flow_id or not run_id: self.logger.warning("Missing flow_id or run_id (or run_number) for a record from table {}".format(table.table_name)) return db_response return await apply_run_tags_to_db_response(flow_id, run_id, table, db_response) return postprocess_chain([_tags_postprocess, refiner_postprocess]) return refiner_postprocess async def load_data_from_db(table, data: Dict[str, Any], filter_dict: Dict = {}, postprocess: Callable = None): # filter the data for loading based on available primary keys conditions_dict = { key: data[key] for key in table.primary_keys if key in data } filter_dict = {**conditions_dict, **filter_dict} conditions, values = [], [] for k, v in filter_dict.items(): conditions.append("{} = %s".format(k)) values.append(v) results, *_ = await table.find_records( conditions=conditions, values=values, fetch_single=True, enable_joins=True, expanded=True, postprocess=postprocess ) return results.body