# sdk/python/generative-ai/promptflow/deploy-flow/streaming/chatapp.py

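"""A console chat app for a promptflow flow deployed to an Azure ML online endpoint.

Each question is POSTed to the endpoint's scoring URI. When the server responds
with `text/event-stream`, the answer is rendered incrementally as server-sent
events arrive; otherwise the full JSON response is printed at once.
Authentication uses AzureCliCredential, so run `az login` first, then:

    python chatapp.py
"""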
import json
import logging
import sys
import time
from datetime import datetime

import requests
from azure.ai.ml import MLClient
from azure.identity import AzureCliCredential

from event_stream import EventStream


class ColoredFormatter(logging.Formatter):
    # Color code dictionary
    color_codes = {
        "debug": "\033[0;32m",  # Green
        "info": "\033[0;36m",  # Cyan
        "warning": "\033[0;33m",  # Yellow
        "error": "\033[0;31m",  # Red
        "critical": "\033[0;35m",  # Magenta
    }

    def format(self, record):
        # Get the original message
        message = super().format(record)
        # Wrap it in the color code for its level, then reset the color
        return f"{self.color_codes.get(record.levelname.lower(), '')}{message}\033[0m"


logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(ColoredFormatter())
logger.setLevel(logging.INFO)
logger.addHandler(handler)


def apply_delta(base: dict, delta: dict):
    """Merge a streaming delta into the accumulated output, concatenating values."""
    for k, v in delta.items():
        if k in base:
            base[k] += v
        else:
            base[k] = v


def score(url, api_key, body, stream=True, on_event=None):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + api_key,
        # The azureml-model-deployment header forces the request to go to a
        # specific deployment. Remove it to have the request observe the
        # endpoint's traffic rules.
        "azureml-model-deployment": "keli19-chat-0727",
        "Accept": "text/event-stream, application/json"
        if stream
        else "application/json",
    }
    logger.info("Sending HTTP request...")
    logger.debug("POST %s", url)
    for name, value in headers.items():
        if name == "Authorization":
            value = "[REDACTED]"
        logger.debug(f">>> {name}: {value}")
    logger.debug(json.dumps(body, indent=4, ensure_ascii=False))
    logger.debug("")

    time1 = datetime.now()
    response = None
    try:
        response = requests.post(url, json=body, headers=headers, stream=stream)
        response.raise_for_status()
    finally:
        time2 = datetime.now()
        if response is not None:
            logger.info(
                "Got response: %d %s (elapsed %s)",
                response.status_code,
                response.reason,
                time2 - time1,
            )
            for name, value in response.headers.items():
                logger.debug(f"<<< {name}: {value}")

    time1 = datetime.now()
    try:
        # Default to "" so a missing Content-Type header does not raise.
        content_type = response.headers.get("Content-Type", "")
        if "text/event-stream" in content_type:
            output = {}
            event_stream = EventStream(response.iter_lines())
            for event in event_stream:
                if on_event:
                    on_event(event)
                dct = json.loads(event.data)
                apply_delta(output, dct)
            return output, True
        else:
            return response.json(), False
    finally:
        time2 = datetime.now()
        logger.info("\nResponse reading elapsed: %s", time2 - time1)


class ChatApp:
    def __init__(
        self,
        ml_client,
        endpoint_name,
        chat_input_name,
        chat_output_name,
        stream=True,
        debug=False,
    ):
        self._chat_input_name = chat_input_name
        self._chat_output_name = chat_output_name
        self._chat_history = []
        self._stream = stream
        if debug:
            logger.setLevel(logging.DEBUG)

        logger.info("Getting endpoint info...")
        endpoint = ml_client.online_endpoints.get(endpoint_name)
        keys = ml_client.online_endpoints.get_keys(endpoint_name)
        self._endpoint_url = endpoint.scoring_uri
        self._endpoint_key = (
            keys.primary_key if endpoint.auth_mode == "key" else keys.access_token
        )
        logger.info("Done.")
        logger.debug(f"Target endpoint: {endpoint.id}")

    @property
    def url(self):
        return self._endpoint_url

    @property
    def api_key(self):
        return self._endpoint_key

    def get_payload(self, chat_input, chat_history=None):
        # Use None instead of a mutable default argument.
        return {
            self._chat_input_name: chat_input,
            "chat_history": chat_history or [],
        }

    def chat_once(self, chat_input):
        def on_event(event):
            dct = json.loads(event.data)
            answer_delta = dct.get(self._chat_output_name)
            if answer_delta:
                print(answer_delta, end="")
                # We need to flush the output; otherwise the text does not
                # appear on the console until a new line comes.
                sys.stdout.flush()
                # Sleep for 20ms for a better animation effect.
                time.sleep(0.02)

        try:
            payload = self.get_payload(
                chat_input=chat_input, chat_history=self._chat_history
            )
            output, stream = score(
                self.url, self.api_key, payload, stream=self._stream, on_event=on_event
            )
            # We don't use self._stream here, since the server may not honor
            # the requested mode; `stream` reflects the actual response type.
            if stream:
                # Print a new line at the end of the content to make sure the
                # next logger line always starts from a new line.
                print("\n")
            else:
                print(output.get(self._chat_output_name, "<empty>"))
            self._chat_history.append(
                {
                    "inputs": {
                        self._chat_input_name: chat_input,
                    },
                    "outputs": output,
                }
            )
            logger.info("Length of chat history: %s", len(self._chat_history))
        except requests.HTTPError as e:
            logger.error(e.response.text)

    def chat(self):
        while True:
            try:
                question = input("Chat with Wikipedia:> ")
                if question in ("exit", "bye"):
                    print("Bye.")
                    break
                self.chat_once(question)
            except KeyboardInterrupt:
                # Exit on Ctrl+C.
                print("\nBye.")
                break
            except Exception as e:
                logger.exception("An error occurred: %s", e)
                # Do not re-raise, so the chat can continue.


if __name__ == "__main__":
    ml_client = MLClient(
        credential=AzureCliCredential(),
        # TODO: Replace with your subscription ID, resource group name, and workspace name
        subscription_id="ee85ed72-2b26-48f6-a0e8-cb5bcf98fbd9",
        resource_group_name="keli19-aml",
        workspace_name="keli19-eastus",
    )

    chat_app = ChatApp(
        ml_client=ml_client,
        # TODO: Replace with your online endpoint name
        endpoint_name="keli19-chat-0727",
        chat_input_name="question",
        chat_output_name="answer",
        stream=True,
        debug=True,
    )
    chat_app.chat()
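

# ---------------------------------------------------------------------------
# Note: the `EventStream` helper imported above ships alongside this sample
# (event_stream.py) and is not shown here. The sketch below is only an
# assumption of its minimal shape, for readers who don't have that file: it
# wraps `response.iter_lines()` and yields events whose `data` attribute
# carries the payload of each server-sent "data:" line, which `score()` then
# passes to `json.loads`. It is a sketch, not the actual implementation.
#
#     class Event:
#         def __init__(self, data: str):
#             self.data = data
#
#     class EventStream:
#         def __init__(self, lines):
#             self._lines = lines  # e.g. response.iter_lines(), yielding bytes
#
#         def __iter__(self):
#             for raw in self._lines:
#                 line = raw.decode("utf-8") if isinstance(raw, bytes) else raw
#                 if line.startswith("data:"):
#                     yield Event(line[len("data:"):].strip())
# ---------------------------------------------------------------------------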