# packages/blueprints/gen-ai-chatbot/static-assets/chatbot-genai-components/backend/python/app/websocket.py
import json
import logging
import os
import traceback
from datetime import datetime
from decimal import Decimal as decimal
import boto3
from app.agents.agent import AgentExecutor, create_react_agent, format_log_to_str
from app.agents.handlers.apigw_websocket import ApigwWebsocketCallbackHandler
from app.agents.handlers.token_count import get_token_count_callback
from app.agents.handlers.used_chunk import get_used_chunk_callback
from app.agents.langchain import BedrockLLM
from app.agents.tools.knowledge import AnswerWithKnowledgeTool
from app.agents.utils import get_tool_by_name
from app.auth import verify_token
from app.bedrock import compose_args
from app.repositories.conversation import RecordNotFoundError, store_conversation
from app.repositories.models.conversation import ChunkModel, ContentModel, MessageModel
from app.routes.schemas.conversation import ChatInput
from app.stream import OnStopInput, get_stream_handler_type
from app.usecases.bot import modify_bot_last_used_time
from app.usecases.chat import insert_knowledge, prepare_conversation, trace_to_root
from app.utils import get_anthropic_client, get_current_time, is_anthropic_model
from app.vector_search import filter_used_results, get_source_link, search_related_docs
from boto3.dynamodb.conditions import Attr, Key
from ulid import ULID

WEBSOCKET_SESSION_TABLE_NAME = os.environ["WEBSOCKET_SESSION_TABLE_NAME"]

client = get_anthropic_client()
dynamodb_client = boto3.resource("dynamodb")
table = dynamodb_client.Table(WEBSOCKET_SESSION_TABLE_NAME)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


def process_chat_input(
user_id: str, chat_input: ChatInput, gatewayapi, connection_id: str
) -> dict:
"""Process chat input and send the message to the client."""
logger.info(f"Received chat input: {chat_input}")
try:
user_msg_id, conversation, bot = prepare_conversation(user_id, chat_input)
except RecordNotFoundError:
if chat_input.bot_id:
gatewayapi.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps(
dict(
status="ERROR",
reason="bot_not_found",
)
).encode("utf-8"),
)
return {"statusCode": 404, "body": f"bot {chat_input.bot_id} not found."}
else:
return {"statusCode": 400, "body": "Invalid request."}
if bot and bot.is_agent_enabled():
logger.info("Bot has agent tools. Using agent for response.")
llm = BedrockLLM.from_model(model=chat_input.message.model)
tools = [get_tool_by_name(t.name) for t in bot.agent.tools]
if bot and bot.has_knowledge():
logger.info("Bot has knowledge. Adding answer with knowledge tool.")
answer_with_knowledge_tool = AnswerWithKnowledgeTool.from_bot(
bot=bot,
llm=llm,
)
tools.append(answer_with_knowledge_tool)
logger.info(f"Tools: {tools}")
agent = create_react_agent(
model=chat_input.message.model,
tools=tools,
generation_config=bot.generation_params,
)
executor = AgentExecutor(
name="Agent Executor",
agent=agent,
tools=tools,
return_intermediate_steps=True,
callbacks=[],
verbose=False,
max_iterations=15,
max_execution_time=None,
early_stopping_method="force",
handle_parsing_errors=True,
)
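
        # Track token cost and retrieved chunks across the agent run via callbacks.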
price = 0.0
used_chunks = None
thinking_log = None
with get_token_count_callback() as token_cb, get_used_chunk_callback() as chunk_cb:
response = executor.invoke(
{
"input": chat_input.message.content[0].body,
},
config={
"callbacks": [
ApigwWebsocketCallbackHandler(gatewayapi, connection_id),
token_cb,
chunk_cb,
],
},
)
price = token_cb.total_cost
if bot.display_retrieved_chunks and chunk_cb.used_chunks:
used_chunks = chunk_cb.used_chunks
thinking_log = format_log_to_str(response.get("intermediate_steps", []))
logger.info(f"Thinking log: {thinking_log}")
        # Append the entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(
content_type="text", body=response["output"], media_type=None
)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=thinking_log,
)
conversation.message_map[assistant_msg_id] = message
        # Register the assistant message as a child of the user message
        conversation.message_map[user_msg_id].children.append(assistant_msg_id)
        conversation.last_message_id = assistant_msg_id
        conversation.total_price += price
        # Store the conversation before streaming finishes so the front end can avoid a 404
store_conversation(user_id, conversation)
        # Send an end signal so the frontend can close the connection
last_data_to_send = json.dumps(
dict(status="STREAMING_END", completion="", stop_reason="agent_finish")
).encode("utf-8")
gatewayapi.post_to_connection(
ConnectionId=connection_id, Data=last_data_to_send
)
return {"statusCode": 200, "body": "Message sent."}
message_map = conversation.message_map
search_results = []
if bot and bot.has_knowledge():
gatewayapi.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps(
dict(
status="FETCHING_KNOWLEDGE",
)
).encode("utf-8"),
)
        # Fetch the most related documents from the vector store
        # NOTE: Embeddings do not yet support multi-modal input, so use the last text content as the query.
query = conversation.message_map[user_msg_id].content[-1].body
search_results = search_related_docs(
bot_id=bot.id, limit=bot.search_params.max_results, query=query
)
logger.info(f"Search results from vector store: {search_results}")
        # Insert the retrieved contexts into the instruction
conversation_with_context = insert_knowledge(
conversation, search_results, display_citation=bot.display_retrieved_chunks
)
message_map = conversation_with_context.message_map
messages = trace_to_root(
node_id=chat_input.message.parent_message_id,
message_map=message_map,
)
messages.append(chat_input.message) # type: ignore
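
    # Build the Bedrock invocation arguments from the traced message history.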
args = compose_args(
messages,
chat_input.message.model,
instruction=(
message_map["instruction"].content[0].body
if "instruction" in message_map
else None
),
stream=True,
generation_params=(bot.generation_params if bot else None),
)
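
    # Streaming callbacks: `on_stream` forwards each token to the client;
    # `on_stop` persists the completed message once the stream ends.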
def on_stream(token: str, **kwargs) -> None:
# Send completion
data_to_send = json.dumps(dict(status="STREAMING", completion=token)).encode(
"utf-8"
)
gatewayapi.post_to_connection(ConnectionId=connection_id, Data=data_to_send)

    def on_stop(arg: OnStopInput, **kwargs) -> None:
used_chunks = None
if bot and bot.display_retrieved_chunks:
if len(search_results) > 0:
used_chunks = []
for r in filter_used_results(arg.full_token, search_results):
content_type, source_link = get_source_link(r.source)
used_chunks.append(
ChunkModel(
content=r.content,
content_type=content_type,
source=source_link,
rank=r.rank,
)
)
        # Append the entire completion as the last message
assistant_msg_id = str(ULID())
message = MessageModel(
role="assistant",
content=[
ContentModel(content_type="text", body=arg.full_token, media_type=None)
],
model=chat_input.message.model,
children=[],
parent=user_msg_id,
create_time=get_current_time(),
feedback=None,
used_chunks=used_chunks,
thinking_log=None,
)
conversation.message_map[assistant_msg_id] = message
        # Register the assistant message as a child of the user message
        conversation.message_map[user_msg_id].children.append(assistant_msg_id)
        conversation.last_message_id = assistant_msg_id
        conversation.total_price += arg.price
        # Store the conversation before streaming finishes so the front end can avoid a 404
store_conversation(user_id, conversation)
last_data_to_send = json.dumps(
dict(status="STREAMING_END", completion="", stop_reason=arg.stop_reason)
).encode("utf-8")
gatewayapi.post_to_connection(
ConnectionId=connection_id, Data=last_data_to_send
)

    stream_handler = get_stream_handler_type(chat_input.message.model)(
model=chat_input.message.model,
on_stream=on_stream,
on_stop=on_stop,
)
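
    # Drive the stream; tokens are emitted via `on_stream` and the final
    # message is persisted in `on_stop`.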
try:
for _ in stream_handler.run(args):
            # `StreamHandler.run` returns a generator, so we need to iterate to drive it
...
except Exception as e:
logger.error(f"Failed to run stream handler: {e}")
return {
"statusCode": 500,
"body": "Failed to run stream handler.",
}
# Update bot last used time
if chat_input.bot_id:
logger.info("Bot id is provided. Updating bot last used time.")
modify_bot_last_used_time(user_id, chat_input.bot_id)
return {"statusCode": 200, "body": "Message sent."}
def handler(event, context):
logger.info(f"Received event: {event}")
route_key = event["requestContext"]["routeKey"]
if route_key == "$connect":
return {"statusCode": 200, "body": "Connected."}
elif route_key == "$disconnect":
return {"statusCode": 200, "body": "Disconnected."}
connection_id = event["requestContext"]["connectionId"]
domain_name = event["requestContext"]["domainName"]
stage = event["requestContext"]["stage"]
endpoint_url = f"https://{domain_name}/{stage}"
gatewayapi = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url)
now = datetime.now()
    expire = int(now.timestamp()) + 60 * 2  # 2 minutes from now
body = json.loads(event["body"])
step = body.get("step")
try:
        # API Gateway (WebSocket) has a hard limit of 32KB per message, so larger
        # messages are split by the client, buffered in DynamoDB, and reassembled
        # here into a single full message before processing.
        # The life cycle of a message is as follows:
        # 1. Client sends a `START` message to the WebSocket API.
        # 2. This handler verifies the JWT, stores the user id, and replies `Session started.`
        # 3. Client sends the message parts to the WebSocket API.
        # 4. This handler stores each part in DynamoDB, keyed by its index.
        # 5. Client sends an `END` message to the WebSocket API.
        # 6. This handler concatenates the parts and sends the full message to Bedrock.
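        # Illustrative payload shapes, matching the fields read below (the exact
        # client framing is an assumption, not taken from this file):
        #   {"step": "START", "token": "<JWT>"}
        #   {"index": 0, "part": "<JSON fragment of a ChatInput>"}
        #   {"step": "END"}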
if step == "START":
token = body["token"]
try:
# Verify JWT token
decoded = verify_token(token)
except Exception as e:
logger.error(f"Invalid token: {e}")
return {"statusCode": 403, "body": "Invalid token."}
user_id = decoded["sub"]
# Store user id
response = table.put_item(
Item={
"ConnectionId": connection_id,
                    # MessagePartId 0 is reserved for this user-id record
"MessagePartId": decimal(0),
"UserId": user_id,
"expire": expire,
}
)
return {"statusCode": 200, "body": "Session started."}
elif step == "END":
            # Retrieve the user id stored by the START step
response = table.query(
KeyConditionExpression=Key("ConnectionId").eq(connection_id),
FilterExpression=Attr("UserId").exists(),
)
user_id = response["Items"][0]["UserId"]
# Concatenate the message parts
message_parts = []
last_evaluated_key = None
            while True:
                # MessagePartId zero is reserved for the user id, so start from 1
                query_kwargs = dict(
                    KeyConditionExpression=Key("ConnectionId").eq(connection_id)
                    & Key("MessagePartId").gte(1)
                )
                if last_evaluated_key:
                    query_kwargs["ExclusiveStartKey"] = last_evaluated_key
                response = table.query(**query_kwargs)
                message_parts.extend(response["Items"])
                if "LastEvaluatedKey" in response:
                    last_evaluated_key = response["LastEvaluatedKey"]
                else:
                    break
logger.info(f"Number of message chunks: {len(message_parts)}")
message_parts.sort(key=lambda x: x["MessagePartId"])
full_message = "".join(item["MessagePart"] for item in message_parts)
# Process the concatenated full message
chat_input = ChatInput(**json.loads(full_message))
return process_chat_input(
user_id=user_id,
chat_input=chat_input,
gatewayapi=gatewayapi,
connection_id=connection_id,
)
else:
            # Store one part of the full message, keyed by its index.
            # MessagePartId zero is reserved for the user id, so client indices shift up by 1.
            part_index = body["index"] + 1
            message_part = body["part"]
table.put_item(
Item={
"ConnectionId": connection_id,
"MessagePartId": decimal(part_index),
"MessagePart": message_part,
"expire": expire,
}
)
return {"statusCode": 200, "body": "Message part received."}
except Exception as e:
logger.error(f"Operation failed: {e}")
logger.error("".join(traceback.format_tb(e.__traceback__)))
gatewayapi.post_to_connection(
ConnectionId=connection_id,
Data=json.dumps({"status": "ERROR", "reason": str(e)}).encode("utf-8"),
)
return {"statusCode": 500, "body": str(e)}