agents/core.py (67 lines of code) (raw):

""" Provides the base class for all Agents """ from abc import ABC import vertexai from google.cloud.aiplatform import telemetry from vertexai.language_models import TextGenerationModel from vertexai.language_models import CodeGenerationModel from vertexai.language_models import CodeChatModel from vertexai.generative_models import GenerativeModel from vertexai.generative_models import HarmCategory,HarmBlockThreshold from utilities import PROJECT_ID, PG_REGION vertexai.init(project=PROJECT_ID, location=PG_REGION) class Agent(ABC): """ The core class for all Agents """ agentType: str = "Agent" def __init__(self, model_id:str): """ model_id is the Model ID for initialization """ self.model_id = model_id if model_id == 'code-bison-32k': with telemetry.tool_context_manager('opendataqna'): self.model = CodeGenerationModel.from_pretrained('code-bison-32k') elif model_id == 'text-bison-32k': with telemetry.tool_context_manager('opendataqna'): self.model = TextGenerationModel.from_pretrained('text-bison-32k') elif model_id == 'codechat-bison-32k': with telemetry.tool_context_manager('opendataqna'): self.model = CodeChatModel.from_pretrained("codechat-bison-32k") elif model_id == 'gemini-1.0-pro': with telemetry.tool_context_manager('opendataqna'): # print("Model is gemini 1.0 pro") self.model = GenerativeModel("gemini-1.0-pro-001") self.safety_settings: Optional[dict] = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } elif model_id == 'gemini-1.5-flash': with telemetry.tool_context_manager('opendataqna'): # print("Model is gemini 1.5 flash") self.model = GenerativeModel("gemini-1.5-flash-preview-0514") self.safety_settings: Optional[dict] = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } elif model_id == 'gemini-1.5-pro': with telemetry.tool_context_manager('opendataqna'): # print("Model is gemini 1.5 Pro") self.model = GenerativeModel("gemini-1.5-pro-001") self.safety_settings: Optional[dict] = { HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE, HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE, } else: raise ValueError("Please specify a compatible model.") def generate_llm_response(self,prompt): context_query = self.model.generate_content(prompt,safety_settings=self.safety_settings,stream=False) return str(context_query.candidates[0].text).replace("```sql", "").replace("```", "").rstrip("\n") def rewrite_question(self,question,session_history): formatted_history='' concat_questions='' for i, _row in enumerate(session_history,start=1): user_question = _row['user_question'] # print(user_question) formatted_history += f"User Question - Turn :: {i} : {user_question}\n" concat_questions += f"{user_question} " # print(formatted_history) context_prompt = f""" Your main objective is to rewrite and refine the question based on the previous questions that has been asked. Refine the given question using the provided questions history to produce a standalone question with full context. The refined question should be self-contained, requiring no additional context for answering it. Make sure all the information is included in the re-written question. You just need to respond with the re-written question. Below is the previous questions history: {formatted_history} Question to rewrite: {question} """ re_written_qe = str(self.generate_llm_response(context_prompt)) print("*"*25 +"Re-written question:: "+"*"*25+"\n"+str(re_written_qe)) return str(concat_questions),str(re_written_qe)