agents/VisualizeAgent.py (52 lines of code) (raw):
#This agent generates google charts code for displaying charts on web application
#Generates two charts with elements "chart-div" and "chart-div-1"
#Code is in javascript
from abc import ABC
from vertexai.language_models import CodeChatModel
from vertexai.generative_models import GenerativeModel,HarmCategory,HarmBlockThreshold
from .core import Agent
from utilities import PROMPTS, format_prompt
from agents import ValidateSQLAgent
import pandas as pd
import json
from google.cloud.aiplatform import telemetry
import vertexai
from utilities import PROJECT_ID, PG_REGION
vertexai.init(project=PROJECT_ID, location=PG_REGION)
class VisualizeAgent(Agent, ABC):
"""
An agent that generates JavaScript code for Google Charts based on user questions and SQL results.
This agent analyzes the user's question and the corresponding SQL query results to determine suitable chart types. It then constructs JavaScript code that uses Google Charts to create visualizations based on the data.
Attributes:
agentType (str): Indicates the type of agent, fixed as "VisualizeAgent".
model_id (str): The ID of the language model used for chart type suggestion and code generation.
model: The language model instance.
Methods:
getChartType(user_question, generated_sql) -> str:
Suggests the two most suitable chart types based on the user's question and the generated SQL query.
Args:
user_question (str): The natural language question asked by the user.
generated_sql (str): The SQL query generated to answer the question.
Returns:
str: A JSON string containing two keys, "chart_1" and "chart_2", each representing a suggested chart type.
getChartPrompt(user_question, generated_sql, chart_type, chart_div, sql_results) -> str:
Creates a prompt for the language model to generate the JavaScript code for a specific chart.
Args:
user_question (str): The user's question.
generated_sql (str): The generated SQL query.
chart_type (str): The desired chart type (e.g., "Bar Chart", "Pie Chart").
chart_div (str): The HTML element ID where the chart will be rendered.
sql_results (str): The results of the SQL query in JSON format.
Returns:
str: The prompt for the language model to generate the chart code.
generate_charts(user_question, generated_sql, sql_results) -> dict:
Generates JavaScript code for two Google Charts based on the given inputs.
Args:
user_question (str): The user's question.
generated_sql (str): The generated SQL query.
sql_results (str): The results of the SQL query in JSON format.
Returns:
dict: A dictionary containing two keys, "chart_div" and "chart_div_1", each holding the generated JavaScript code for a chart.
"""
agentType: str ="VisualizeAgent"
def __init__(self):
self.model_id = 'gemini-1.5-pro'
self.model = GenerativeModel("gemini-1.5-pro-001")
def getChartType(self,user_question, generated_sql):
chart_type_prompt = PROMPTS['visualize_chart_type']
chart_type_prompt = format_prompt(chart_type_prompt,
user_question = user_question,
generated_sql = generated_sql)
chart_type=self.model.generate_content(chart_type_prompt, stream=False).candidates[0].text
# print(chart_type)
# chart_type = model.predict(map_prompt, max_output_tokens = 1024, temperature= 0.2).candidates[0].text
return chart_type.replace("\n", "").replace("```", "").replace("json", "").replace("```html", "").replace("```", "").replace("js\n","").replace("json\n","").replace("python\n","").replace("javascript","")
def getChartPrompt(self,user_question, generated_sql, chart_type, chart_div, sql_results):
chart_prompt = PROMPTS['visualize_generate_chart_code']
chart_prompt = format_prompt(chart_prompt,
user_question = user_question,
generated_sql = generated_sql,
chart_type = chart_type,
chart_div = chart_div,
sql_results = sql_results)
# print(f"Prompt to generate code for google charts visualization after formatting: \n{chart_prompt}")
return chart_prompt
def generate_charts(self,user_question,generated_sql,sql_results):
chart_type = self.getChartType(user_question,generated_sql)
# chart_type = chart_type.split(",")
# chart_list = [x.strip() for x in chart_type]
chart_json = json.loads(chart_type)
chart_list =[chart_json['chart_1'],chart_json['chart_2']]
print("Charts Suggested : " + str(chart_list))
context_prompt=self.getChartPrompt(user_question,generated_sql,chart_list[0],"chart_div",sql_results)
context_prompt_1=self.getChartPrompt(user_question,generated_sql,chart_list[1],"chart_div_1",sql_results)
safety_settings: Optional[dict] = {
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
with telemetry.tool_context_manager('opendataqna-visualize-v2'):
context_query = self.model.generate_content(context_prompt,safety_settings=safety_settings, stream=False)
context_query_1 = self.model.generate_content(context_prompt_1,safety_settings=safety_settings, stream=False)
google_chart_js={"chart_div":context_query.candidates[0].text.replace("```json", "").replace("```", "").replace("json", "").replace("```html", "").replace("```", "").replace("js","").replace("json","").replace("python","").replace("javascript",""),
"chart_div_1":context_query_1.candidates[0].text.replace("```json", "").replace("```", "").replace("json", "").replace("```html", "").replace("```", "").replace("js","").replace("json","").replace("python","").replace("javascript","")}
return google_chart_js