services/sse-api/src/sse_api/watcher.py (140 lines of code) (raw):

# SPDX-License-Identifier: Apache-2.0
# Copyright 2023 The HuggingFace Authors.

import asyncio
import contextlib
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Optional
from uuid import uuid4

from motor.motor_asyncio import AsyncIOMotorClient
from pymongo.errors import PyMongoError

from sse_api.constants import HUB_CACHE_KIND

# Shape of a 'dataset-hub-cache' cache entry's content (schema defined elsewhere).
DatasetHubCacheResponse = Mapping[str, Any]


class ChangeStreamInitError(Exception):
    """Raised when the MongoDB change stream could not be started at all."""

    pass


@dataclass
class HubCacheChangedEventValue:
    dataset: str
    hub_cache: Optional[DatasetHubCacheResponse]
    # ^ None if the dataset has been deleted, or the value is an error response


class HubCacheChangedEvent(asyncio.Event):
    """Subclass of asyncio.Event which is able to send a value to the waiter"""

    _hub_cache_value: Optional[HubCacheChangedEventValue]

    def __init__(self, *, hub_cache_value: Optional[HubCacheChangedEventValue] = None):
        super().__init__()
        self._hub_cache_value = hub_cache_value
        # Start in the "set" state so a fresh subscriber is woken immediately
        # (it receives the initial value, possibly None).
        super().set()

    def set_value(self, *, hub_cache_value: Optional[HubCacheChangedEventValue] = None) -> None:
        """Store a new value, then wake up all the waiters."""
        self._hub_cache_value = hub_cache_value
        return super().set()

    async def wait_value(self) -> Optional[HubCacheChangedEventValue]:
        """The caller is responsible to call self.clear() when the event has been handled"""
        await super().wait()
        return self._hub_cache_value


@dataclass
class HubCachePublisher:
    # Maps subscriber uuid -> its event. A subscriber awaits its event to be
    # notified of hub-cache changes.
    _watchers: dict[str, HubCacheChangedEvent]

    def _notify_change(
        self,
        *,
        dataset: str,
        hub_cache: Optional[DatasetHubCacheResponse],
        suscriber: Optional[str] = None,
        # ^ NOTE(review): "suscriber" (sic) is the established keyword name across
        # this module — kept as-is for caller compatibility.
    ) -> None:
        """Publish a change to one subscriber (if `suscriber` is given) or to all of them."""
        hub_cache_value = HubCacheChangedEventValue(dataset=dataset, hub_cache=hub_cache)
        for watcher, event in self._watchers.items():
            if suscriber is None or suscriber == watcher:
                event.set_value(hub_cache_value=hub_cache_value)

    def _unsubscribe(self, uuid: str) -> None:
        """Remove a subscriber. Raises KeyError if the uuid is unknown."""
        self._watchers.pop(uuid)

    def _subscribe(self) -> tuple[str, HubCacheChangedEvent]:
        """Register a new subscriber and return its (uuid, event) pair."""
        event = HubCacheChangedEvent()
        uuid = uuid4().hex
        self._watchers[uuid] = event
        return (uuid, event)


class HubCacheWatcher:
    """
    Utility to watch the value of the cache entries with kind 'dataset-hub-cache'.
    """

    _watch_task: asyncio.Task[None]  # <- not sure about the type
    # ^ only assigned by start_watching(); stop_watching() before that raises AttributeError

    def __init__(self, client: AsyncIOMotorClient, db_name: str, collection_name: str) -> None:
        self._client = client
        self._collection = self._client[db_name][collection_name]
        self._publisher = HubCachePublisher(_watchers={})

    def run_initialization(self, suscriber: str) -> asyncio.Task[Any]:
        """Start a background task that replays the existing cache entries to one subscriber."""
        return asyncio.create_task(self._init_loop(suscriber=suscriber))

    def start_watching(self) -> None:
        """Start the background task that watches the MongoDB change stream."""
        self._watch_task = asyncio.create_task(self._watch_loop())

    async def stop_watching(self) -> None:
        """Cancel the watch task and wait for it to terminate."""
        self._watch_task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await self._watch_task

    def subscribe(self) -> tuple[str, HubCacheChangedEvent]:
        """
        Subscribe to the hub-cache value changes.

        The caller is responsible for calling `self.unsubscribe` to release resources.

        Returns:
            `tuple[str, HubCacheChangedEvent]`: A 2-tuple containing a UUID and an instance of
            HubCacheChangedEvent. HubCacheChangedEvent can be `await`ed to be notified of updates
            to the hub-cache value. UUID must be passed when unsubscribing to release the
            associated resources.
        """
        return self._publisher._subscribe()

    def unsubscribe(self, uuid: str) -> None:
        """
        Release resources allocated to subscribe to the hub-cache value changes.
        """
        pub = self._publisher
        pub._unsubscribe(uuid)

    async def _init_loop(self, suscriber: str) -> None:
        """
        publish an event for each initial dataset-hub-cache cache entry.

        TODO: we don't want to send to all the suscribers
        """
        async for document in self._collection.find(
            filter={"kind": HUB_CACHE_KIND},
            projection={"dataset": 1, "content": 1, "http_status": 1},
            sort=[("_id", 1)],
            batch_size=1,
        ):
            # ^ should we use batch_size=100 instead, and send a list of contents?
            dataset = document["dataset"]
            self._publisher._notify_change(
                suscriber=suscriber,
                dataset=dataset,
                # an entry whose http_status is not OK is published as None (error response)
                hub_cache=(document["content"] if document["http_status"] == HTTPStatus.OK else None),
            )

    async def _watch_loop(self) -> None:
        """
        publish a new event, on every change in a dataset-hub-cache cache entry.
        """
        pipeline: Sequence[Mapping[str, Any]] = [
            {
                "$match": {
                    "$or": [
                        {"fullDocument.kind": HUB_CACHE_KIND},
                        {"fullDocumentBeforeChange.kind": HUB_CACHE_KIND},
                    ],
                    "operationType": {"$in": ["insert", "update", "replace", "delete"]},
                },
            },
            {
                "$project": {
                    "fullDocument": 1,
                    "fullDocumentBeforeChange": 1,
                    "updateDescription": 1,
                    "operationType": 1,
                },
            },
        ]
        resume_token = None
        while True:
            try:
                async with self._collection.watch(
                    pipeline,
                    resume_after=resume_token,
                    full_document="updateLookup",
                    full_document_before_change="whenAvailable",
                ) as stream:
                    async for change in stream:
                        # remember where we are, to resume after a transient error
                        resume_token = stream.resume_token
                        operation = change["operationType"]
                        if (
                            operation == "delete"
                            and "fullDocumentBeforeChange" in change
                            and change["fullDocumentBeforeChange"]["kind"] == HUB_CACHE_KIND
                        ):
                            dataset = change["fullDocumentBeforeChange"]["dataset"]
                            self._publisher._notify_change(dataset=dataset, hub_cache=None)
                            continue

                        full_document = change.get("fullDocument")
                        if full_document is None or full_document["kind"] != HUB_CACHE_KIND:
                            # With full_document="updateLookup", "fullDocument" can be missing
                            # or None when the document was deleted between the change and the
                            # lookup; indexing it unconditionally would raise KeyError/TypeError,
                            # which is not a PyMongoError and would kill the watch task.
                            continue

                        if operation == "update" and not any(
                            field in change["updateDescription"]["updatedFields"]
                            for field in ["content", "http_status"]
                        ):
                            # ^ no change, skip
                            continue

                        self._publisher._notify_change(
                            dataset=full_document["dataset"],
                            hub_cache=(
                                full_document["content"]
                                if full_document["http_status"] == HTTPStatus.OK
                                else None
                            ),
                        )
            except PyMongoError:
                if resume_token is None:
                    # the stream never delivered anything: we cannot resume, give up
                    raise ChangeStreamInitError()
                # otherwise loop and re-open the stream with resume_after=resume_token