src/jobs/util/language_model.py (116 lines of code) (raw):

import concurrent.futures
import json
import logging
import os
import re
import time
from enum import Enum
from typing import List, TypedDict

import pandas as pd
from openai import OpenAI


class ModelType(Enum):
    """Model identifiers accepted by :class:`LanguageModel`.

    The enum value is sent verbatim as the ``model`` argument of the
    chat-completions request.
    """

    OpenAIChat = "gpt-3.5-turbo"
    # NOTE(review): the text-* entries below look like legacy completion-model
    # names, while LanguageModel only calls the chat-completions endpoint —
    # confirm they are still accepted there.
    OpenAIDaVinci = "text-davinci-003"
    OpenAICurie = "text-curie-001"
    OpenAIBabbage = "text-babbage-001"
    OpenAIAda = "text-ada-001"
    OpenAI40 = "gpt-4o"


class MessageSequenceItem(TypedDict):
    """One chat message: the speaker ``role`` and the message ``content``."""

    role: str
    content: str


# Process-wide record of every successful query (prompt prefix + response),
# appended to by stats_log and written out by dump_lang_logs.
# NOTE(review): nothing clears this list, so it grows for the process lifetime.
log_items: List[dict] = []


def stats_log(log_item):
    """Append ``log_item`` to the module-level ``log_items`` record."""
    log_items.append(log_item)


def dump_lang_logs(filename):
    """Write every recorded log item to ``filename`` as ``{"logs": [...]}`` JSON."""
    with open(filename, "w", encoding="utf-8") as ff:
        json.dump({"logs": log_items}, ff)


# A leading list number such as "1." or "12)" plus the whitespace after it.
# The (?!\d) guard keeps decimals like "1.5 million" intact.
_LIST_NUMBER_PREFIX = re.compile(r"^\d+[.)](?!\d)\s*")


class LanguageModel:
    """
    Wrapper for OpenAI or similar language models
    """

    def __init__(self, engine: ModelType = ModelType.OpenAI40, temperature=0.1,
                 logger: logging.Logger = logging.getLogger()):
        """
        :param engine: model to query; its ``value`` is sent as the model name
        :param temperature: sampling temperature used for every request
        :param logger: logger for request/response diagnostics, or ``None``
            to disable logging
        """
        self._engine = engine.value
        self._temperature = temperature
        self._logger = logger
        if logger is not None:
            # NOTE(review): with the default argument this is the *root* logger,
            # so constructing an instance enables DEBUG output process-wide.
            # Kept for backward compatibility.
            logger.setLevel(logging.DEBUG)
        self._retry_delay_sec = 1
        # NOTE(review): not read by this class; retries are set per call via
        # the ``retry_count`` argument.
        self._num_retries = 3
        self.setup_openai()

    def setup_openai(self):
        """Create the OpenAI client from the ``OPENAI_API_KEY`` environment variable."""
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def cleanup_result(self, test: str) -> str:
        """Lower-case ``test``, remove every ``.`` and strip surrounding whitespace."""
        return test.lower().replace(".", "").strip()

    @classmethod
    def user_message(cls, text) -> MessageSequenceItem:
        """Build a ``user`` chat message containing ``text``."""
        return {"role": "user", "content": text}

    @classmethod
    def assistant_message(cls, text) -> MessageSequenceItem:
        """Build an ``assistant`` chat message containing ``text``."""
        return {"role": "assistant", "content": text}

    def generic_query(self, messages: List[MessageSequenceItem], retry_count=0,
                      max_tokens=None, json_only=False) -> str:
        """
        Send a chat-completion request and return the first choice's text.

        :param messages: conversation to send, oldest first
        :param retry_count: extra attempts allowed after a failure. With the
            default of 0, the first error is re-raised. Otherwise, once all
            ``retry_count + 1`` attempts have failed, an empty string is returned.
        :param max_tokens: optional cap on the response length
        :param json_only: request a JSON-object response format
        :return: Response text ("" if the API returned no content), or empty
            string if every retry failed
        :raises Exception: whatever the client raised, when ``retry_count == 0``
        """
        response_format = {"type": "json_object"} if json_only else None
        # The client takes plain dicts; convert once instead of on every attempt.
        request_messages = [dict(m) for m in messages]
        retry_allowance = retry_count
        while retry_allowance >= 0:
            retry_allowance -= 1
            try:
                url_response = self.client.chat.completions.create(
                    model=self._engine,
                    temperature=self._temperature,
                    messages=request_messages,
                    max_tokens=max_tokens,
                    response_format=response_format)
                # Message content may be null; keep the declared ``str`` contract.
                response = url_response.choices[0].message.content or ""
                stats_log({"message": request_messages[0]["content"][0:50],
                           "response": response})
                if self._logger is not None:
                    self._logger.debug(f"OpenAI Response {response}")
                return response
            except Exception as ex:
                if self._logger is not None:
                    self._logger.error(f"Error calling openai {ex}")
                if retry_count == 0:
                    raise
                # Only back off when another attempt will actually follow.
                if retry_allowance >= 0 and self._retry_delay_sec > 0:
                    time.sleep(self._retry_delay_sec)
        if self._logger is not None:
            self._logger.error("Exceeded retry count. Returning empty string result")
        return ""

    def text_query(self, text: str, retry_count=0) -> str:
        """
        Do a basic openai query with a single user message
        :param text: Prompt
        :param retry_count: see ``generic_query``
        :return: Response, or empty string if there is a failure
        """
        return self.generic_query([self.user_message(text)], retry_count=retry_count)

    def ask_list_boolean(self, item_list, prompt, if_error=False, max_error=10, num_threads=4):
        """
        Ask the same yes/no ``prompt`` about every item, issuing requests in parallel.

        :param item_list: items to ask about; each is embedded between ``###`` markers
        :param prompt: question text placed before every item
        :param if_error: value used for an item whose cleaned response is neither
            "0" nor "1" (including the empty result after all retries fail)
        :param max_error: accepted for backward compatibility; currently not enforced
        :param num_threads: number of worker threads issuing requests
        :return: list of booleans in the same order as ``item_list``
        """
        def get_boolean_result(item):
            combined_query = f"{prompt} Answer 1 for yes and 0 for no on a single line without any additional comments. ###{item}###"
            response = self.text_query(combined_query, retry_count=5)
            res = self.cleanup_result(response)
            if res == "0":
                return False
            if res == "1":
                return True
            if self._logger is not None:
                self._logger.warning(f"Boolean request failed with response {response}")
            return if_error

        # executor.map preserves input order in its results.
        with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor:
            return list(executor.map(get_boolean_result, item_list))

    def process_list_response(self, response):
        """
        Split a line-per-item response into clean list entries.

        Each line is stripped. A leading list number such as ``1.`` or ``12)``
        is removed. Lines that end up empty are dropped. Items that merely
        start with a digit (e.g. ``"3D printing"``) are kept intact.

        :param response: raw response text
        :return: list of item strings
        """
        cleaned = (_LIST_NUMBER_PREFIX.sub("", line.strip()).strip()
                   for line in response.splitlines())
        return [item for item in cleaned if item]

    def list_query(self, query) -> List[str]:
        """
        Use to call language model with a query that has a list of results.
        Typically, the query string should include 'Respond with each topic alone on a separate line'
        :return: a list
        """
        openai_response = self.text_query(query)
        return self.process_list_response(openai_response)

    def ask_df(self, message: str, key_items: List[str]) -> pd.DataFrame:
        """
        Ask for a JSON list of records and return them as a DataFrame.

        :param message: the request text
        :param key_items: keys each record should contain (the expected columns)
        :return: DataFrame built from the ``RESULT`` list of the JSON response
        :raises json.JSONDecodeError: if the response is not valid JSON, including
            the empty string returned when every retry fails
        :raises KeyError: if the JSON object has no ``RESULT`` key
        """
        key_item_text = ",\n".join(f'"{key_item}": "<text-only>"' for key_item in key_items)
        # No trailing comma after the last key, and the outer object is closed.
        json_pattern = "{ \"RESULT\": [ {" + key_item_text + "} , ...] }"
        complete_message = f"{message} Respond in JSON list format to match the following exact pattern with exact keys: " \
                           f"{json_pattern}"
        response = self.generic_query([self.user_message(complete_message)], retry_count=2, json_only=True)
        if self._logger is not None:
            self._logger.debug(f"Parsing response {response}")
        array = json.loads(response)["RESULT"]
        return pd.DataFrame(data=array)

    def get_list(self, message_str):
        """Ask for a list and return the values of each record's ``list_item`` key."""
        res_df = self.ask_df(message_str, ["list_item"])
        # An empty RESULT list yields a DataFrame with no columns at all.
        if res_df.shape[0] == 0:
            return []
        return res_df["list_item"].to_list()