packages/blueprints/gen-ai-chatbot/static-assets/chatbot-genai-components/backend/python/app/stream.py (112 lines of code) (raw):

import json
import logging
from typing import Any, Callable, Generator, Optional

from anthropic.types import ContentBlockDeltaEvent, MessageDeltaEvent, MessageStopEvent
from app.bedrock import calculate_price, get_bedrock_response, get_model_id
from app.routes.schemas.conversation import type_model_name
from app.utils import get_anthropic_client, is_anthropic_model
from langchain_core.outputs import GenerationChunk
from pydantic import BaseModel

logger = logging.getLogger(__name__)


def get_stream_handler_type(model: type_model_name):
    model_id = get_model_id(model)
    if is_anthropic_model(model_id):
        return AnthropicStreamHandler
    else:
        return BedrockStreamHandler


class OnStopInput(BaseModel):
    full_token: str
    stop_reason: str
    input_token_count: int
    output_token_count: int
    price: float


class BaseStreamHandler:
    def __init__(
        self,
        model: type_model_name,
        on_stream: Callable[[str], GenerationChunk | None],
        on_stop: Callable[[OnStopInput], GenerationChunk | None],
    ):
        """Base class for stream handlers.
        :param model: Model name.
        :param on_stream: Callback function invoked for each streamed token.
        :param on_stop: Callback function invoked when the stream stops.
        """
        self.model = model
        self.on_stream = on_stream
        self.on_stop = on_stop

    def run(self, args: dict):
        raise NotImplementedError()

    @classmethod
    def from_model(cls, model: type_model_name):
        # Dispatch to the concrete handler class for the given model,
        # with no-op callbacks until bind() is called.
        return get_stream_handler_type(model)(
            model=model, on_stream=lambda x: None, on_stop=lambda x: None
        )

    def bind(
        self, on_stream: Callable[[str], Any], on_stop: Callable[[OnStopInput], Any]
    ):
        self.on_stream = on_stream
        self.on_stop = on_stop
        return self


class AnthropicStreamHandler(BaseStreamHandler):
    """Stream handler for Anthropic models."""

    def run(self, args: dict):
        client = get_anthropic_client()
        response = client.messages.create(**args)
        completions = []
        stop_reason = ""
        for event in response:
            # NOTE: the following is an example event sequence:
            # MessageStartEvent(message=Message(id='compl_01GwmkwncsptaeBopeaR4eWE', content=[], model='claude-instant-1.2', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=21, output_tokens=1)), type='message_start')
            # ContentBlockStartEvent(content_block=ContentBlock(text='', type='text'), index=0, type='content_block_start')
            # ...
            # ContentBlockDeltaEvent(delta=TextDelta(text='です', type='text_delta'), index=0, type='content_block_delta')
            # ContentBlockStopEvent(index=0, type='content_block_stop')
            # MessageDeltaEvent(delta=Delta(stop_reason='end_turn', stop_sequence=None), type='message_delta', usage=MessageDeltaUsage(output_tokens=26))
            # MessageStopEvent(type='message_stop', amazon-bedrock-invocationMetrics={'inputTokenCount': 21, 'outputTokenCount': 25, 'invocationLatency': 621, 'firstByteLatency': 279})
            if isinstance(event, ContentBlockDeltaEvent):
                completions.append(event.delta.text)
                res = self.on_stream(event.delta.text)
                yield res
            elif isinstance(event, MessageDeltaEvent):
                logger.debug(f"Received message delta event: {event.delta}")
                stop_reason = str(event.delta.stop_reason)
            elif isinstance(event, MessageStopEvent):
                concatenated = "".join(completions)
                metrics = event.model_dump()["amazon-bedrock-invocationMetrics"]
                input_token_count = metrics.get("inputTokenCount")
                output_token_count = metrics.get("outputTokenCount")
                price = calculate_price(
                    self.model, input_token_count, output_token_count
                )
                res = self.on_stop(
                    OnStopInput(
                        full_token=concatenated,
                        stop_reason=stop_reason,
                        input_token_count=input_token_count,
                        output_token_count=output_token_count,
                        price=price,
                    )
                )
                yield res
            else:
                continue


class BedrockStreamHandler(BaseStreamHandler):
    """Stream handler for Bedrock models (e.g. Mistral)."""

    def run(self, args: dict):
        response = get_bedrock_response(args)
        completions = []
        stop_reason = ""
        for event in response.get("body"):  # type: ignore
            chunk = event.get("chunk")
            if chunk:
                msg_chunk = json.loads(chunk.get("bytes").decode())
                stop_reason = msg_chunk["outputs"][0]["stop_reason"]
                if not stop_reason:
                    msg: str = msg_chunk["outputs"][0]["text"]
                    completions.append(msg)
                    res = self.on_stream(msg)
                    yield res
                else:
                    concatenated = "".join(completions)
                    metrics = msg_chunk["amazon-bedrock-invocationMetrics"]
                    input_token_count = metrics.get("inputTokenCount")
                    output_token_count = metrics.get("outputTokenCount")
                    price = calculate_price(
                        self.model, input_token_count, output_token_count
                    )
                    res = self.on_stop(
                        OnStopInput(
                            full_token=concatenated,
                            stop_reason=stop_reason,
                            input_token_count=input_token_count,
                            output_token_count=output_token_count,
                            price=price,
                        )
                    )
                    yield res
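
# ---------------------------------------------------------------------------
# Illustrative payloads (not part of the original module): the shape of the
# decoded Mistral-on-Bedrock chunks that BedrockStreamHandler.run() consumes,
# inferred from the fields it reads above. The metric values are placeholders
# that mirror the Anthropic invocation-metrics example comment.
# ---------------------------------------------------------------------------
_EXAMPLE_TEXT_CHUNK = {
    # Intermediate chunk: stop_reason is null, so run() streams outputs[0].text.
    "outputs": [{"text": "Hello", "stop_reason": None}],
}
_EXAMPLE_FINAL_CHUNK = {
    # Final chunk: a non-empty stop_reason plus the invocation metrics that
    # run() feeds into calculate_price() and OnStopInput.
    "outputs": [{"text": "", "stop_reason": "stop"}],
    "amazon-bedrock-invocationMetrics": {
        "inputTokenCount": 21,
        "outputTokenCount": 25,
        "invocationLatency": 621,
        "firstByteLatency": 279,
    },
}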
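
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how a caller
# might wire up a handler. `_demo_stream` and its callbacks are hypothetical
# names; `args` is assumed to be the invoke payload built elsewhere in the app.
# ---------------------------------------------------------------------------
def _demo_stream(model: type_model_name, args: dict) -> str:
    """Drive a handler end-to-end and return the concatenated completion."""
    chunks: list[str] = []

    def _on_stream(token: str) -> None:
        # Invoked once per streamed text delta.
        chunks.append(token)

    def _on_stop(arg: OnStopInput) -> None:
        # Invoked once at end of stream with the stop reason, token counts,
        # and the computed price.
        logger.info(f"stop_reason={arg.stop_reason}, price={arg.price}")

    handler = BaseStreamHandler.from_model(model).bind(
        on_stream=_on_stream, on_stop=_on_stop
    )
    # run() is a generator: it must be consumed for the callbacks to fire.
    for _ in handler.run(args):
        pass
    return "".join(chunks)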