# services/ui_backend_service/api/notify.py
import json
import asyncio
from typing import Dict, List
from services.utils import logging
from services.data.postgres_async_db import (
FLOW_TABLE_NAME, RUN_TABLE_NAME,
STEP_TABLE_NAME, TASK_TABLE_NAME,
METADATA_TABLE_NAME, ARTIFACT_TABLE_NAME
)
from pyee import AsyncIOEventEmitter


class ListenNotify(object):
"""
Class for starting an async listener task that listens on a DB connection for notifications,
and processes these as events before broadcasting them on the provided event_emitter.
Consumes messages from DB connection with 'LISTEN notify' and processes the contents before passing
to event_emitter.emit('notify', *args)
Parameters
----------
db : AsyncPostgresDB
initialized instance of a postgresDB adapter
event_emitter : AsyncIOEventEmitter
Any event emitter class that implements .emit('notify', *args)
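
    Example
    -------
    A minimal wiring sketch (hypothetical; the actual setup lives in the
    service startup code, and the real 'notify' consumer is in api/ws.py):

        event_emitter = AsyncIOEventEmitter()
        ListenNotify(app, db, event_emitter)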
"""

    def __init__(self, app, db, event_emitter=None):
self.event_emitter = event_emitter or AsyncIOEventEmitter()
self.db = db
self.logger = logging.getLogger("ListenNotify")
self.loop = asyncio.get_event_loop()
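        # Start listening as a background task; _init reconnects on failure.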
self.loop.create_task(self._init(self.db.pool))

    async def _init(self, pool):
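        """
        Reconnect loop: acquire a connection from the pool and run the listen
        and ping tasks over it; on any failure, wait a second and start over.
        """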
while True:
try:
async with pool.acquire() as conn:
self.logger.info("Connection acquired")
await asyncio.gather(
self.listen(conn),
self.ping(conn)
)
except Exception as ex:
self.logger.warning(str(ex))
finally:
await asyncio.sleep(1)

    async def listen(self, conn):
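        """
        LISTEN on the 'notify' channel and dispatch every received message to
        handle_trigger_msg. Notifications arrive on the connection's notifies
        queue (an asyncio.Queue under aiopg), which is polled here.
        """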
async with conn.cursor() as cur:
await cur.execute("LISTEN notify")
while not cur.closed:
try:
msg = conn.notifies.get_nowait()
self.loop.create_task(self.handle_trigger_msg(msg))
except asyncio.QueueEmpty:
await asyncio.sleep(0.1)
except Exception:
self.logger.exception("Exception when listening to notify.")

    async def ping(self, conn):
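        """
        Keepalive: NOTIFY an otherwise unused 'listen' channel once a second so
        the connection stays active and broken connections surface promptly.
        """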
async with conn.cursor() as cur:
while not cur.closed:
try:
await cur.execute("NOTIFY listen")
except Exception:
self.logger.debug("Exception NOTIFY ping.")
finally:
await asyncio.sleep(1)

    async def handle_trigger_msg(self, msg):
        "Handler for messages received from 'LISTEN notify' (msg carries a JSON payload)."
try:
payload = json.loads(msg.payload)
table_name = payload.get("table")
operation = payload.get("operation")
data = payload.get("data")
table = self.db.get_table_by_name(table_name)
if table is not None:
resources = resource_list(table.table_name, data)
# Broadcast this event to `api/ws.py` (Websocket.event_handler)
# and notify each Websocket connection about this event
if resources is not None and len(resources) > 0:
await _broadcast(self.event_emitter, operation, table, data)
# Heartbeat watcher for Runs.
if table.table_name == self.db.run_table_postgres.table_name:
self.event_emitter.emit('run-heartbeat', 'update', data)
# Heartbeat watcher for Tasks.
if table.table_name == self.db.task_table_postgres.table_name:
self.event_emitter.emit('task-heartbeat', 'update', data)
                    # Task updates also act as a keepalive for the run's heartbeat watcher.
                    self.event_emitter.emit('run-heartbeat', 'update', data)
                    # Broadcast a run update as well, as runs would otherwise receive no
                    # updates while their tasks are executing.
                    await _broadcast(self.event_emitter, "UPDATE", self.db.run_table_postgres, data)
# Notify when Run parameters are ready.
if operation == "INSERT" and \
table.table_name == self.db.step_table_postgres.table_name and \
data["step_name"] == "start":
self.event_emitter.emit("preload-run-parameters", data['flow_id'], data['run_number'])
# Notify task resources of a new attempt if 'attempt' metadata is inserted.
if operation == "INSERT" and \
table.table_name == self.db.metadata_table_postgres.table_name and \
data["field_name"] == "attempt":
# Extract the attempt number from metadata attempt value, so we know which task attempt to broadcast.
_attempt_id = int(data.get("value", 0))
# First attempt has already been inserted by task table trigger.
# Later attempts must count as inserts to register properly for the UI
_op = "UPDATE" if _attempt_id == 0 else "INSERT"
await _broadcast(
event_emitter=self.event_emitter,
operation=_op,
table=self.db.task_table_postgres,
data=data,
filter_dict={"attempt_id": _attempt_id}
)
# Notify related resources once attempt_ok for task has been saved.
if operation == "INSERT" and \
table.table_name == self.db.metadata_table_postgres.table_name and \
data["field_name"] == "attempt_ok":
attempt_id = None
try:
attempt_tag = [t for t in data['tags'] if t.startswith('attempt_id')][0]
attempt_id = attempt_tag.split(":")[1]
except Exception:
self.logger.exception("Failed to load attempt_id from attempt_ok metadata")
pass
# remove heartbeat watcher for completed task
self.event_emitter.emit("task-heartbeat", "complete", data)
# broadcast task status as it has either completed or failed.
# TODO: Might be necessary to broadcast with a specific attempt_id.
await _broadcast(
event_emitter=self.event_emitter,
operation="UPDATE",
table=self.db.task_table_postgres,
data=data,
filter_dict={"attempt_id": attempt_id} if attempt_id else {}
)
# Notify updated Run status once attempt_ok metadata for end step has been received
if data["step_name"] == "end":
await _broadcast(self.event_emitter, "UPDATE", self.db.run_table_postgres, data)
                            # Also trigger preload of task statuses (derived from artifacts) once a run finishes.
self.event_emitter.emit("preload-task-statuses", data['flow_id'], data['run_number'])
# And remove possible heartbeat watchers for completed runs
self.event_emitter.emit("run-heartbeat", "complete", data)
                # Notify the DAG cache store to preload the DAG from the run's code package.
if operation == "INSERT" and \
table.table_name == self.db.metadata_table_postgres.table_name and \
data["step_name"] == "start" and \
data["field_name"] in ["code-package-url", "code-package"]:
self.event_emitter.emit("preload-dag", data['flow_id'], data['run_number'])
except Exception:
self.logger.exception("Exception occurred")


def resource_list(table_name: str, data: Dict) -> List[str]:
"""
List of RESTful resources that the provided table and data are included in.
Used for determining which Web Socket subscriptions this resource relates to.
Parameters
----------
table_name : str
table name that the Data belongs to
data : Dict
Dictionary of the data for a record of the table.
Returns
-------
List
example:
[
"/runs",
"/flows/ExampleFlow/runs",
"/flows/ExampleFlow/runs/1234"
]
"""
resource_paths = {
FLOW_TABLE_NAME: [
"/flows",
"/flows/{flow_id}"
],
RUN_TABLE_NAME: [
"/runs",
"/flows/{flow_id}/runs",
"/flows/{flow_id}/runs/{run_number}"
],
STEP_TABLE_NAME: [
"/flows/{flow_id}/runs/{run_number}/steps",
"/flows/{flow_id}/runs/{run_number}/steps/{step_name}"
],
TASK_TABLE_NAME: [
"/flows/{flow_id}/runs/{run_number}/tasks",
"/flows/{flow_id}/runs/{run_number}/steps/{step_name}/tasks",
"/flows/{flow_id}/runs/{run_number}/steps/{step_name}/tasks/{task_id}",
"/flows/{flow_id}/runs/{run_number}/steps/{step_name}/tasks/{task_id}/attempts"
],
ARTIFACT_TABLE_NAME: [
"/flows/{flow_id}/runs/{run_number}/artifacts",
"/flows/{flow_id}/runs/{run_number}/steps/{step_name}/artifacts",
"/flows/{flow_id}/runs/{run_number}/steps/{step_name}/tasks/{task_id}/artifacts",
],
METADATA_TABLE_NAME: [
"/flows/{flow_id}/runs/{run_number}/metadata",
"/flows/{flow_id}/runs/{run_number}/steps/{step_name}/tasks/{task_id}/metadata"
]
}
if table_name in resource_paths:
return [path.format(**data) for path in resource_paths[table_name]]
return []


async def _broadcast(event_emitter, operation: str, table, data: Dict, filter_dict=None):
    """Resolve the resource paths for a record and emit a 'notify' event for the WebSocket layer."""
    filter_dict = filter_dict or {}
    _resources = resource_list(table.table_name, data)
event_emitter.emit('notify', operation, _resources, data, table.table_name, filter_dict)
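
# The 'notify' event emitted above carries the arguments
# (operation, resources, data, table_name, filter_dict). A minimal subscriber
# sketch (hypothetical wiring; the actual consumer is Websocket.event_handler
# in api/ws.py, as referenced in handle_trigger_msg):
#
#     @event_emitter.on('notify')
#     async def on_notify(operation, resources, data, table_name, filter_dict):
#         for path in resources:
#             ...  # fan the event out to subscriptions matching `path`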