projects/streamlit-pubsub/streamlit_pubsub.py (138 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Streamlit and Pub/Sub integration methods. These methods provide both subscription and publish functionality to Streamlit. The subscription is much more sophisticated, as it provides a central cache per-Subscription of received messages and allows multiple Streamlit dashboards to share the same subscription data. """ import asyncio import atexit import datetime import threading import uuid from google import pubsub_v1 from google.api_core.gapic_v1.client_info import ClientInfo from google.cloud import pubsub from google.protobuf import duration_pb2 from google.protobuf import timestamp_pb2 import streamlit as st logger = st.logger.get_logger(__name__) class Buffer: """Fixed size sequence-based buffer. This fixed size buffer enables a subscription to add data, as fast as possible, and then fetch (up to the buffer size) the most recent N messages. A sequence number is maintained so the subscriber does not get duplicates, but may miss messages. """ def __init__(self, max_elem=100): self.max_elem = max_elem self.data = [] self.next_idx = 0 # zero is front of queue self.seq = -1 # zero is the first seqId (-1 is nothing) def add(self, elem): self.seq += 1 if self.next_idx == len(self.data): self.data.append(elem) else: self.data[self.next_idx] = elem self.next_idx += 1 if self.next_idx == self.max_elem: self.next_idx = 0 def get_elems(self, last_seq): """Return all new messages since last_seq, up to max_elem.""" required_elem = self.seq - last_seq if required_elem > self.max_elem: required_elem = self.max_elem if required_elem <= 0: return (self.seq, []) if required_elem <= self.next_idx: start_idx = self.next_idx - required_elem return (self.seq, self.data[start_idx : self.next_idx]) start_end_idx = len(self.data) - required_elem + self.next_idx return ( self.seq, self.data[start_end_idx:] + self.data[: self.next_idx], ) class BufferedAsyncData: """Thread-safe Buffer wrapper with additional async methods. This enables sharing the buffer between threads and exposing it into Streamlit dashboards. """ def __init__(self, max_messages=100): self.buf = Buffer(max_elem=max_messages) self.seq_id = 0 self.cv = threading.Condition() def update(self, new_data: bytes): with self.cv: self.buf.add(new_data) self.cv.notify_all() def get_latest_data(self, last_seq_id: int = -1, timeout: float = None): """Thread-safe fetch of the latest data since last_seq_id. The timeout is used to return early if specified (empty list). Args: last_seq_id: The last sequence id (highest) that was fetched. timeout: The time to wait before returning empty list. Returns: A tuple of new high seq_id and a list of data. If timeout was specified, then the high seq_id may be the same as last_seq_id and the list of data empty. """ while True: with self.cv: seq_id, data = self.buf.get_elems(last_seq_id) # If data is found, return it as is if data: return seq_id, data # If timed out, just return the empty data if not self.cv.wait(timeout=timeout): return seq_id, data async def aget_latest_data( self, last_seq_id: int = -1, timeout: float = None ): return await asyncio.to_thread( self.get_latest_data, last_seq_id, timeout ) async def aget_latest_st_data(self, seq_id_key="key", timeout: float = 1.0): """Get the latest data as Streamlit, using the seq_id_key. This is designed for a Streamlit dashboard using seq_id_key, with a frequent timeout. Args: seq_id_key: Key in the st.session_state for storing the seq_id. timeout: Timeout between loops. Recommended to be fairly small. Returns: List of data when there is new data to process. """ if seq_id_key not in st.session_state: st.session_state[seq_id_key] = -1 while True: (st.session_state[seq_id_key], data) = await self.aget_latest_data( last_seq_id=st.session_state[seq_id_key], timeout=timeout ) # If no data (e.g., timeout occurred), just try again if data: return data # Capture script startup time -- if seeking to now START_TIMESTAMP = timestamp_pb2.Timestamp() START_TIMESTAMP.GetCurrentTime() @st.cache_resource(show_spinner=False) def get_subscriber_client(): # Create the subscription client and path return pubsub.SubscriberClient( # User Agent allows us to track and prioritise # supporting this integration client_info=ClientInfo( user_agent="cloud-solutions/streamlit-pubsub-v1", ) ) @st.cache_resource(show_spinner=True) def get_subscriber( project_id: str, topic: str, max_messages: int, message_retention: datetime.timedelta = datetime.timedelta(minutes=10), subscription_ttl: datetime.timedelta = datetime.timedelta(days=1), ) -> BufferedAsyncData: """Create a new Streamlit buffer for a subscription. Args: project_id: Project ID for the topic and auto-created subscription topic: Topic for creating a subscription for max_messages: Maximum messages to buffer in streamlit that is shared between consumers. message_retention: Retention period in Pub/Sub for subscribers. It is recommended to set this at a minimum, i.e., 10 minutes. subscription_ttl: TTL for inactive subscription. Normally, the subscription should be automatically deleted at exit but the TTL can help cleanup unused subscriptions. It is recommended to set this to the minimum, i.e., 1 day. Returns: BufferedAsyncdata: Thread-safe shared buffer that Streamlit sessions can pull data from. """ # Create the subscription client and path sub = get_subscriber_client() # Create subscription ID subscription = f"{topic}-{str(uuid.uuid1())}" subscription_id = sub.subscription_path(project_id, subscription) # Initialize the durations subscription_ttl_dur = duration_pb2.Duration() subscription_ttl_dur.FromTimedelta(subscription_ttl) message_retention_dur = duration_pb2.Duration() message_retention_dur.FromTimedelta(message_retention) # Create the subscription logger.info("Creating subscription %s, topic %s", subscription_id, topic) sub.create_subscription( request=pubsub_v1.Subscription( name=subscription_id, topic=sub.topic_path(project_id, topic), message_retention_duration=message_retention_dur, expiration_policy=pubsub_v1.ExpirationPolicy( ttl=subscription_ttl_dur, ), ) ) # Create the AsyncData container and callback md = BufferedAsyncData(max_messages=max_messages) def callback(msg): md.update(msg.data) msg.ack() # Subscribe into the callback logger.info("Subscribing %s", subscription_id) fut = sub.subscribe(subscription_id, callback) # Shutdown of subscription at process termination def shutdown(): # Cancel the subscription logger.info("Stopping subscription %s", subscription_id) fut.cancel() # Delete the subscription logger.info("Deleting subscription %s", subscription_id) sub.delete_subscription( pubsub_v1.DeleteSubscriptionRequest(subscription=subscription_id) ) logger.info("Deleted subscription %s", subscription_id) atexit.register(shutdown) return md @st.cache_resource(show_spinner=False) def get_publisher_client(): # Create the subscription client and path return pubsub.PublisherClient( # User Agent allows us to track and prioritise # supporting this integration client_info=ClientInfo( user_agent="cloud-solutions/streamlit-pubsub-v1", ) ) @st.cache_resource(show_spinner=False) def get_publisher(project_id: str, topic: str): client = get_publisher_client() full_topic_id = client.topic_path(project_id, topic) def publisher(data): client.publish(topic=full_topic_id, data=data) return publisher