projects/streamlit-pubsub/streamlit_pubsub.py
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Streamlit and Pub/Sub integration methods.
These methods provide both subscription and publish functionality to
Streamlit. The subscription is much more sophisticated, as it provides a
central cache per-Subscription of received messages and allows multiple
Streamlit dashboards to share the same subscription data.
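
Example (a minimal sketch; "my-project" and "my-topic" are placeholder
names for your own project and topic):

    import asyncio

    # A shared, cached buffer fed by an auto-created subscription
    md = get_subscriber("my-project", "my-topic", max_messages=100)

    # Block until new messages arrive, then render them
    for msg in asyncio.run(md.aget_latest_st_data(seq_id_key="msgs")):
        st.write(msg.decode("utf-8"))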
"""

import asyncio
import atexit
import datetime
import threading
from typing import Optional
import uuid

from google import pubsub_v1
from google.api_core.gapic_v1.client_info import ClientInfo
from google.cloud import pubsub
from google.protobuf import duration_pb2
from google.protobuf import timestamp_pb2
import streamlit as st

logger = st.logger.get_logger(__name__)


class Buffer:
    """Fixed-size, sequence-based buffer.

    This fixed-size buffer enables a subscription to add data as fast as
    possible, and lets readers fetch (up to the buffer size) the most
    recent N messages. A sequence number is maintained so the subscriber
    does not get duplicates, but it may miss messages that were
    overwritten.
    """

    def __init__(self, max_elem=100):
        self.max_elem = max_elem
        self.data = []
        self.next_idx = 0  # zero is front of queue
        self.seq = -1  # zero is the first seqId (-1 is nothing)

    def add(self, elem):
        self.seq += 1
        if self.next_idx == len(self.data):
            self.data.append(elem)
        else:
            self.data[self.next_idx] = elem
        self.next_idx += 1
        # Wrap around once the buffer is full (ring-buffer behaviour)
        if self.next_idx == self.max_elem:
            self.next_idx = 0

    def get_elems(self, last_seq):
        """Return all new messages since last_seq, up to max_elem."""
        required_elem = self.seq - last_seq
        if required_elem > self.max_elem:
            required_elem = self.max_elem
        if required_elem <= 0:
            return (self.seq, [])
        if required_elem <= self.next_idx:
            start_idx = self.next_idx - required_elem
            return (self.seq, self.data[start_idx : self.next_idx])
        # The buffer has wrapped: stitch the tail and head back together
        start_end_idx = len(self.data) - required_elem + self.next_idx
        return (
            self.seq,
            self.data[start_end_idx:] + self.data[: self.next_idx],
        )
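

# A worked example of the Buffer semantics (illustration only):
#
#   buf = Buffer(max_elem=3)
#   for i in range(5):
#       buf.add(i)       # seq advances 0..4; elements 0 and 1 get overwritten
#   buf.get_elems(0)     # -> (4, [2, 3, 4]): only the 3 newest survive
#   buf.get_elems(4)     # -> (4, []): nothing new since seq 4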


class BufferedAsyncData:
    """Thread-safe Buffer wrapper with additional async methods.

    This enables sharing the buffer between threads and exposing it to
    Streamlit dashboards.
    """

    def __init__(self, max_messages=100):
        self.buf = Buffer(max_elem=max_messages)
        self.cv = threading.Condition()

    def update(self, new_data: bytes):
        with self.cv:
            self.buf.add(new_data)
            self.cv.notify_all()

    def get_latest_data(
        self, last_seq_id: int = -1, timeout: Optional[float] = None
    ):
        """Thread-safe fetch of the latest data since last_seq_id.

        If a timeout is specified, the call returns early (with an empty
        list) when no new data arrives in time.

        Args:
            last_seq_id: The last (highest) sequence id that was fetched.
            timeout: The time to wait before returning an empty list.

        Returns:
            A tuple of the new high seq_id and a list of data. If a timeout
            was specified, the high seq_id may be the same as last_seq_id
            and the list of data empty.
        """
        while True:
            with self.cv:
                seq_id, data = self.buf.get_elems(last_seq_id)
                # If data is found, return it as is
                if data:
                    return seq_id, data
                # If timed out, just return the empty data
                if not self.cv.wait(timeout=timeout):
                    return seq_id, data

    async def aget_latest_data(
        self, last_seq_id: int = -1, timeout: Optional[float] = None
    ):
        # Run the blocking fetch in a worker thread so the event loop
        # stays responsive
        return await asyncio.to_thread(
            self.get_latest_data, last_seq_id, timeout
        )

    async def aget_latest_st_data(
        self, seq_id_key="key", timeout: float = 1.0
    ):
        """Get the latest data for a Streamlit session, using seq_id_key.

        This is designed for a Streamlit dashboard: the last-seen sequence
        id is tracked in st.session_state under seq_id_key, and the wait
        loops on a frequent timeout.

        Args:
            seq_id_key: Key in st.session_state for storing the seq_id.
            timeout: Timeout between loops. Recommended to be fairly small.

        Returns:
            List of data when there is new data to process.
        """
        if seq_id_key not in st.session_state:
            st.session_state[seq_id_key] = -1
        while True:
            (st.session_state[seq_id_key], data) = await self.aget_latest_data(
                last_seq_id=st.session_state[seq_id_key], timeout=timeout
            )
            # If no data (e.g., timeout occurred), just try again
            if data:
                return data
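

# Example of the thread-safe handoff (illustration only):
#
#   md = BufferedAsyncData(max_messages=10)
#   md.update(b"hello")                   # producer (e.g., Pub/Sub callback)
#   seq, data = md.get_latest_data(-1)    # -> (0, [b"hello"])
#   md.get_latest_data(seq, timeout=0.1)  # -> (0, []) once the wait times out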


# Capture script startup time -- used if seeking a subscription to "now"
START_TIMESTAMP = timestamp_pb2.Timestamp()
START_TIMESTAMP.GetCurrentTime()


@st.cache_resource(show_spinner=False)
def get_subscriber_client():
    # Create the cached subscriber client
    return pubsub.SubscriberClient(
        # The user agent allows us to track and prioritise
        # supporting this integration
        client_info=ClientInfo(
            user_agent="cloud-solutions/streamlit-pubsub-v1",
        )
    )


@st.cache_resource(show_spinner=True)
def get_subscriber(
    project_id: str,
    topic: str,
    max_messages: int,
    message_retention: datetime.timedelta = datetime.timedelta(minutes=10),
    subscription_ttl: datetime.timedelta = datetime.timedelta(days=1),
) -> BufferedAsyncData:
    """Create a new Streamlit buffer for a subscription.

    Args:
        project_id: Project ID for the topic and the auto-created
            subscription.
        topic: Topic to create a subscription for.
        max_messages: Maximum number of messages to buffer in Streamlit;
            the buffer is shared between consumers.
        message_retention: Retention period in Pub/Sub for subscribers.
            It is recommended to set this to the minimum, i.e., 10 minutes.
        subscription_ttl: TTL for an inactive subscription. Normally the
            subscription is deleted automatically at exit, but the TTL
            helps clean up unused subscriptions. It is recommended to set
            this to the minimum, i.e., 1 day.

    Returns:
        BufferedAsyncData: Thread-safe shared buffer that Streamlit
            sessions can pull data from.
    """
    # Get the cached subscriber client
    sub = get_subscriber_client()

    # Create a unique subscription ID for this process
    subscription = f"{topic}-{str(uuid.uuid1())}"
    subscription_id = sub.subscription_path(project_id, subscription)

    # Initialize the durations
    subscription_ttl_dur = duration_pb2.Duration()
    subscription_ttl_dur.FromTimedelta(subscription_ttl)
    message_retention_dur = duration_pb2.Duration()
    message_retention_dur.FromTimedelta(message_retention)

    # Create the subscription
    logger.info("Creating subscription %s, topic %s", subscription_id, topic)
    sub.create_subscription(
        request=pubsub_v1.Subscription(
            name=subscription_id,
            topic=sub.topic_path(project_id, topic),
            message_retention_duration=message_retention_dur,
            expiration_policy=pubsub_v1.ExpirationPolicy(
                ttl=subscription_ttl_dur,
            ),
        )
    )

    # Create the AsyncData container and callback
    md = BufferedAsyncData(max_messages=max_messages)

    def callback(msg):
        md.update(msg.data)
        msg.ack()

    # Subscribe into the callback
    logger.info("Subscribing %s", subscription_id)
    fut = sub.subscribe(subscription_id, callback)

    # Shut down the subscription at process termination
    def shutdown():
        # Cancel the streaming pull
        logger.info("Stopping subscription %s", subscription_id)
        fut.cancel()

        # Delete the subscription
        logger.info("Deleting subscription %s", subscription_id)
        sub.delete_subscription(
            pubsub_v1.DeleteSubscriptionRequest(subscription=subscription_id)
        )
        logger.info("Deleted subscription %s", subscription_id)

    atexit.register(shutdown)
    return md
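

# A dashboard might drive the subscriber like this (a sketch; the
# st.rerun() loop is one common pattern for streaming displays, not the
# only one):
#
#   md = get_subscriber("my-project", "my-topic", max_messages=100)
#   data = asyncio.run(md.aget_latest_st_data(seq_id_key="msgs"))
#   st.write(f"Received {len(data)} new message(s)")
#   st.rerun()  # rerun the script to wait for the next batch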


@st.cache_resource(show_spinner=False)
def get_publisher_client():
    # Create the cached publisher client
    return pubsub.PublisherClient(
        # The user agent allows us to track and prioritise
        # supporting this integration
        client_info=ClientInfo(
            user_agent="cloud-solutions/streamlit-pubsub-v1",
        )
    )


@st.cache_resource(show_spinner=False)
def get_publisher(project_id: str, topic: str):
    """Return a cached publish function for the given project and topic."""
    client = get_publisher_client()
    full_topic_id = client.topic_path(project_id, topic)

    def publisher(data):
        client.publish(topic=full_topic_id, data=data)

    return publisher
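

# Example publish side (a sketch; "my-project" and "my-topic" are
# placeholders):
#
#   publish = get_publisher("my-project", "my-topic")
#   publish(b"hello from streamlit")  # fire-and-forget; the future returned
#                                     # by client.publish() is discarded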