in packages/blueprints/gen-ai-chatbot/static-assets/chatbot-genai-components/backend/python/app/websocket.py [0:0]
def handler(event, context):
    """Lambda entry point for the chatbot's API Gateway WebSocket API.

    ``$connect`` and ``$disconnect`` are acknowledged immediately. Every other
    route carries a JSON body whose ``step`` field drives a chunked-upload
    protocol, because a single WebSocket message is size-limited:

    1. ``{"step": "START", "token": ...}``: the JWT is verified and the user
       id (its ``sub`` claim) is stored in ``table`` under MessagePartId 0.
    2. ``{"index": i, "part": ...}`` (any other ``step``): the chunk is
       stored under MessagePartId ``i + 1``.
    3. ``{"step": "END"}``: all chunks for this connection are read back,
       concatenated in MessagePartId order, parsed as JSON into a
       ``ChatInput`` and passed to ``process_chat_input``.

    Every stored item gets an ``expire`` value two minutes in the future
    (presumably the table's TTL attribute; confirm in the infrastructure).

    Args:
        event: API Gateway WebSocket event.
        context: Lambda context object (unused).

    Returns:
        A ``{"statusCode": ..., "body": ...}`` dict. For END this is whatever
        ``process_chat_input`` returns. An invalid token yields 403. Any other
        failure while handling a step is pushed to the client as
        ``{"status": "ERROR", "reason": ...}`` and returned as status 500.
    """
    # Log the event without its body: the START body carries the user's JWT
    # and message parts carry user content, neither of which belongs in logs.
    loggable_event = {key: value for key, value in event.items() if key != "body"}
    logger.info(f"Received event: {loggable_event}")

    request_context = event["requestContext"]
    route_key = request_context["routeKey"]
    if route_key == "$connect":
        return {"statusCode": 200, "body": "Connected."}
    elif route_key == "$disconnect":
        return {"statusCode": 200, "body": "Disconnected."}

    connection_id = request_context["connectionId"]
    domain_name = request_context["domainName"]
    stage = request_context["stage"]
    # Management API client used to push messages back to this connection.
    endpoint_url = f"https://{domain_name}/{stage}"
    gatewayapi = boto3.client("apigatewaymanagementapi", endpoint_url=endpoint_url)
    expire = int(datetime.now().timestamp()) + 60 * 2  # 2 minutes from now

    try:
        # API Gateway (websocket) has a hard limit of 32KB per message, so a
        # larger payload is sent in chunks. Chunks are stored in DynamoDB and,
        # once the client signals END, concatenated and processed as a whole.
        #
        # Parse inside the try so a malformed body is reported to the client
        # instead of escaping as an unhandled Lambda error.
        body = json.loads(event["body"])
        step = body.get("step")

        if step == "START":
            token = body["token"]
            try:
                # Verify JWT token
                decoded = verify_token(token)
            except Exception as e:
                logger.error(f"Invalid token: {e}")
                return {"statusCode": 403, "body": "Invalid token."}
            # MessagePartId 0 is reserved for the session item holding the user id.
            table.put_item(
                Item={
                    "ConnectionId": connection_id,
                    "MessagePartId": decimal(0),
                    "UserId": decoded["sub"],
                    "expire": expire,
                }
            )
            return {"statusCode": 200, "body": "Session started."}

        elif step == "END":
            # Use strongly consistent reads throughout: the items were written
            # by earlier invocations moments ago, and an eventually consistent
            # read may miss the newest ones, yielding an incomplete message.
            session = table.get_item(
                Key={"ConnectionId": connection_id, "MessagePartId": decimal(0)},
                ConsistentRead=True,
            ).get("Item")
            if not session or "UserId" not in session:
                raise ValueError(
                    "Session not found. Send START with a valid token first."
                )
            user_id = session["UserId"]

            # Read every chunk (MessagePartId >= 1), following pagination.
            query_kwargs = {
                "KeyConditionExpression": Key("ConnectionId").eq(connection_id)
                & Key("MessagePartId").gte(1),
                "ConsistentRead": True,
            }
            message_parts = []
            while True:
                response = table.query(**query_kwargs)
                message_parts.extend(response["Items"])
                if "LastEvaluatedKey" not in response:
                    break
                query_kwargs["ExclusiveStartKey"] = response["LastEvaluatedKey"]

            logger.info(f"Number of message chunks: {len(message_parts)}")
            message_parts.sort(key=lambda item: item["MessagePartId"])
            full_message = "".join(item["MessagePart"] for item in message_parts)
            # Process the concatenated full message
            chat_input = ChatInput(**json.loads(full_message))
            return process_chat_input(
                user_id=user_id,
                chat_input=chat_input,
                gatewayapi=gatewayapi,
                connection_id=connection_id,
            )

        else:
            # Store one chunk of the full message. Chunk ids start at 1 because
            # 0 is the session item: index -1 would overwrite it, and lower
            # indices would be skipped by the END query, so reject them.
            index = body["index"]
            if isinstance(index, bool) or not isinstance(index, int) or index < 0:
                raise ValueError(f"Invalid message part index: {index!r}")
            table.put_item(
                Item={
                    "ConnectionId": connection_id,
                    "MessagePartId": decimal(index + 1),
                    "MessagePart": body["part"],
                    "expire": expire,
                }
            )
            return {"statusCode": 200, "body": "Message part received."}

    except Exception as e:
        logger.error(f"Operation failed: {e}")
        logger.error("".join(traceback.format_tb(e.__traceback__)))
        try:
            gatewayapi.post_to_connection(
                ConnectionId=connection_id,
                Data=json.dumps({"status": "ERROR", "reason": str(e)}).encode("utf-8"),
            )
        except Exception as notify_error:
            # The client may already be disconnected; don't let that mask the
            # original failure or skip the 500 response below.
            logger.error(f"Failed to notify client of error: {notify_error}")
        return {"statusCode": 500, "body": str(e)}