bot/code/bots/state_management_bot.py (87 lines of code) (raw):
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import asyncio
import json
import time
import urllib
import urllib.error
import urllib.request
from datetime import datetime

from botbuilder.azure import CosmosDbPartitionedStorage, CosmosDbPartitionedConfig
from botbuilder.core import ActivityHandler, ConversationState, TurnContext, UserState
from botbuilder.schema import ChannelAccount

from data_models import ConversationData, ConversationMessage
class StateManagementBot(ActivityHandler):
    """Bot that answers user messages through a remote LLM endpoint.

    Every incoming message is POSTed to ``llm_endpoint`` together with the chat
    history kept in conversation state.  The endpoint's ``answer`` is sent back
    to the user and the exchange is appended to that history.
    """

    def __init__(
        self,
        conversation_state: ConversationState,
        llm_endpoint: str,
        categories: str,
        organization_urls: list[str],
        organization: str,
        llm_api_key: str,
        welcome_message: str,
        llm_deployment: str = "chat-with-website-1",
    ):
        """Create the bot.

        Args:
            conversation_state: Bot Framework conversation state used to persist
                ``ConversationData`` between turns.  Required.
            llm_endpoint: URL the question / chat-history payload is POSTed to.
            categories: Forwarded unchanged to the endpoint as ``categories``.
            organization_urls: Forwarded unchanged as ``organization_urls``.
            organization: Forwarded unchanged as ``organization``.
            llm_api_key: Sent as a Bearer token in the ``Authorization`` header.
            welcome_message: Text sent to every newly added member.
            llm_deployment: Value of the ``azureml-model-deployment`` header, which
                forces requests to one specific deployment.  Pass an empty string
                or None to omit the header so the endpoint's traffic rules apply.

        Raises:
            TypeError: If ``conversation_state`` is None.
        """
        # Fail fast on the one dependency the bot cannot work without.
        if conversation_state is None:
            raise TypeError(
                "[StateManagementBot]: Missing parameter. conversation_state is required but None was given"
            )
        self.llm_endpoint = llm_endpoint
        self.categories = categories
        self.organization_urls = organization_urls
        self.organization = organization
        self.llm_api_key = llm_api_key
        self.welcome_message = welcome_message
        self.llm_deployment = llm_deployment
        self.conversation_state = conversation_state
        self.conversation_data_accessor = self.conversation_state.create_property("ConversationData")

    async def on_turn(self, turn_context: TurnContext):
        """Dispatch the activity as usual, then persist conversation state.

        State is saved only when dispatch completes; an exception raised by a
        handler propagates before ``save_changes`` runs.
        """
        await super().on_turn(turn_context)
        await self.conversation_state.save_changes(turn_context)

    async def on_members_added_activity(
        self, members_added: list[ChannelAccount], turn_context: TurnContext
    ):
        """Send ``welcome_message`` once for each added member other than the recipient."""
        for member in members_added:
            # Skip the activity's recipient, i.e. the bot's own account.
            if member.id != turn_context.activity.recipient.id:
                await turn_context.send_activity(self.welcome_message)

    async def on_message_activity(self, turn_context: TurnContext):
        """Answer a user message with the LLM and record the exchange in conversation state."""
        activity = turn_context.activity
        conversation_data = await self.conversation_data_accessor.get(
            turn_context, ConversationData
        )

        # Record who/when/where for this turn.
        conversation_data.timestamp = self.__datetime_from_utc_to_local(activity.timestamp)
        conversation_data.session_id = activity.conversation.id
        conversation_data.user_id = f"{activity.from_property.id}-{activity.channel_id}"
        if conversation_data.messages is None:
            conversation_data.messages = []

        # call_llm performs blocking HTTP I/O; run it in a worker thread so the
        # event loop keeps serving other turns while the endpoint responds.
        answer = await asyncio.to_thread(
            self.call_llm, activity.text, conversation_data.messages
        )
        if answer is None:
            # call_llm already printed the HTTP failure details.  Tell the user
            # instead of crashing on json.loads(None), and keep the failed turn
            # out of the stored history.
            await turn_context.send_activity(
                "Sorry, I couldn't get an answer right now. Please try again."
            )
            return

        message = {
            "inputs": {
                "question": activity.text,
                "categories": self.categories,
                "organization_urls": self.organization_urls,
                "organization": self.organization,
            },
            "outputs": json.loads(answer),
        }
        conversation_data.messages.append(message)
        await turn_context.send_activity(message["outputs"]["answer"])

    def __datetime_from_utc_to_local(self, utc_datetime):
        """Shift a UTC datetime by this host's UTC offset and format it.

        Returns a string such as ``"03:04:05 PM, Tuesday, January 02 of 2024"``,
        or None when no timestamp is available.
        """
        if utc_datetime is None:
            return None
        # Local-minus-UTC offset in effect *now* (same value the old
        # fromtimestamp - utcfromtimestamp difference produced, without the
        # deprecated utcfromtimestamp).
        offset = datetime.now().astimezone().utcoffset()
        return (utc_datetime + offset).strftime("%I:%M:%S %p, %A, %B %d of %Y")

    def call_llm(self, question, chat_history=None):
        """POST a question and chat history to the LLM endpoint.

        Args:
            question: The user's message text.
            chat_history: Prior exchanges sent as ``chat_history``; defaults to
                an empty list.  A fresh list is used per call (no shared default).

        Returns:
            The raw response body as bytes, or None if the endpoint answered with
            an HTTP error (status, headers and body are printed).

        Raises:
            ValueError: If no API key was configured.
            urllib.error.URLError: Connection-level failures are not caught.
        """
        if chat_history is None:
            chat_history = []

        api_key = self.llm_api_key
        if not api_key:
            raise ValueError("A key should be provided to invoke the endpoint")

        body = json.dumps(
            {
                "question": question,
                "chat_history": chat_history,
                "categories": self.categories,
                "organization_urls": self.organization_urls,
                "organization": self.organization,
            }
        ).encode("utf-8")

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + api_key,
        }
        # The azureml-model-deployment header forces the request to a specific
        # deployment; without it the endpoint's traffic rules decide.
        if self.llm_deployment:
            headers["azureml-model-deployment"] = self.llm_deployment

        req = urllib.request.Request(self.llm_endpoint, body, headers)
        try:
            # The context manager closes the response (and its socket) after reading.
            with urllib.request.urlopen(req) as response:
                return response.read()
        except urllib.error.HTTPError as error:
            print("The request failed with status code: " + str(error.code))
            # The headers include the request ID and the timestamp, which are
            # useful for debugging the failure.
            print(error.info())
            print(error.read().decode("utf8", "ignore"))
            return None