agents/core.py (67 lines of code) (raw):
"""
Provides the base class for all Agents
"""
from abc import ABC
from typing import Optional

import vertexai
from google.cloud.aiplatform import telemetry
from vertexai.language_models import TextGenerationModel
from vertexai.language_models import CodeGenerationModel
from vertexai.language_models import CodeChatModel
from vertexai.generative_models import GenerativeModel
from vertexai.generative_models import HarmCategory,HarmBlockThreshold

from utilities import PROJECT_ID, PG_REGION
# Initialise the Vertex AI SDK once, when this module is imported, using the
# project id and region exported by the local `utilities` module.
vertexai.init(
    project=PROJECT_ID,
    location=PG_REGION,
)
class Agent(ABC):
    """
    The core class for all Agents.

    An Agent wraps a single Vertex AI model handle (``self.model``) selected
    by ``model_id`` and offers helpers to send a prompt to that model and to
    rewrite a follow-up question into a standalone one.
    """

    agentType: str = "Agent"

    # Class-level default so that *every* instance exposes the attribute.
    # Only the Gemini models get a populated mapping in ``__init__``; without
    # this default, reading ``self.safety_settings`` on an agent built for a
    # ``*-bison-32k`` model raised AttributeError.
    safety_settings: Optional[dict] = None

    # Public Gemini model id -> versioned model name passed to GenerativeModel.
    _GEMINI_MODEL_VERSIONS = {
        'gemini-1.0-pro': "gemini-1.0-pro-001",
        'gemini-1.5-flash': "gemini-1.5-flash-preview-0514",
        'gemini-1.5-pro': "gemini-1.5-pro-001",
    }

    def __init__(self,
                 model_id: str):
        """
        Load the model identified by ``model_id`` and store it on ``self.model``.

        Supported ids are 'code-bison-32k', 'text-bison-32k',
        'codechat-bison-32k', 'gemini-1.0-pro', 'gemini-1.5-flash' and
        'gemini-1.5-pro'. For the Gemini ids, ``self.safety_settings`` is set
        to a mapping that applies ``BLOCK_NONE`` to the harassment,
        hate-speech, sexually-explicit and dangerous-content categories; for
        the other ids it stays ``None``.

        Raises:
            ValueError: if ``model_id`` is not one of the supported ids.
        """
        self.model_id = model_id

        gemini_version = self._GEMINI_MODEL_VERSIONS.get(model_id)
        if gemini_version is not None:
            with telemetry.tool_context_manager('opendataqna'):
                self.model = GenerativeModel(gemini_version)
            # One shared definition instead of a copy per Gemini variant.
            self.safety_settings = {
                category: HarmBlockThreshold.BLOCK_NONE
                for category in (
                    HarmCategory.HARM_CATEGORY_HARASSMENT,
                    HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                    HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                    HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                )
            }
            return

        # Resolve the model class first so that an unsupported id is rejected
        # before any SDK call is attempted.
        if model_id == 'code-bison-32k':
            model_class = CodeGenerationModel
        elif model_id == 'text-bison-32k':
            model_class = TextGenerationModel
        elif model_id == 'codechat-bison-32k':
            model_class = CodeChatModel
        else:
            raise ValueError("Please specify a compatible model.")

        with telemetry.tool_context_manager('opendataqna'):
            self.model = model_class.from_pretrained(model_id)

    def generate_llm_response(self, prompt) -> str:
        """
        Send ``prompt`` to the model and return the cleaned-up text answer.

        Calls ``self.model.generate_content`` (non-streaming) with the agent's
        safety settings and takes the text of the first candidate. Markdown
        code fences ("```sql" and "```") are removed and trailing newlines are
        stripped; leading whitespace is left untouched.
        """
        response = self.model.generate_content(
            prompt,
            safety_settings=self.safety_settings,
            stream=False,
        )
        answer = str(response.candidates[0].text)
        return answer.replace("```sql", "").replace("```", "").rstrip("\n")

    def rewrite_question(self, question, session_history):
        """
        Rewrite ``question`` as a standalone question using earlier questions.

        Args:
            question: the follow-up question to rewrite.
            session_history: iterable of mappings, each providing a
                'user_question' entry; iterated in the order given.

        Returns:
            A tuple ``(concat_questions, re_written_question)`` where
            ``concat_questions`` is every previous question followed by a
            single space, concatenated in order, and ``re_written_question``
            is the model's answer as returned by ``generate_llm_response``.

        The rewritten question is also printed to stdout.
        """
        previous_questions = [row['user_question'] for row in session_history]

        formatted_history = "".join(
            f"User Question - Turn :: {turn} : {asked}\n"
            for turn, asked in enumerate(previous_questions, start=1)
        )
        # Each question keeps its trailing space (including the last one) so
        # the returned value is unchanged for existing callers.
        concat_questions = "".join(f"{asked} " for asked in previous_questions)

        # Assembled from adjacent literals so the prompt text does not depend
        # on how this method happens to be indented.
        context_prompt = (
            "\n"
            "Your main objective is to rewrite and refine the question based on the previous questions that has been asked.\n"
            "Refine the given question using the provided questions history to produce a standalone question with full context. The refined question should be self-contained, requiring no additional context for answering it.\n"
            "Make sure all the information is included in the re-written question. You just need to respond with the re-written question.\n"
            "Below is the previous questions history:\n"
            f"{formatted_history}\n"
            "Question to rewrite:\n"
            f"{question}\n"
        )

        re_written_qe = str(self.generate_llm_response(context_prompt))
        print("*"*25 + "Re-written question:: " + "*"*25 + "\n" + re_written_qe)
        return concat_questions, re_written_qe