UI/evaluation/eval.py
"""
Evaluation main code for different NL2SQL frameworks
"""
import os
import json
import time
import requests
import pandas as pd
# from langchain_google_vertexai import VertexAI
from google.cloud import bigquery
import streamlit as st
from loguru import logger
from dbai_src.dbai import DBAI_nl2sql
LITE_API_PART = 'lite'
FEW_SHOT_GENERATION = "Few Shot"
GEN_BY_CORE = "CORE_EXECUTORS"
GEN_BY_LITE = "LITE_EXECUTORS"
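# GEN_BY_CORE and GEN_BY_LITE are the names of the environment variables that
# hold the base URLs of the two hosted NL2SQL backends, set just below.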
os.environ[GEN_BY_LITE] = (
    "https://nl2sqlstudio-lite-prod-dot-sl-test-project-363109"
    ".uc.r.appspot.com"
)
os.environ[GEN_BY_CORE] = (
    "https://nl2sqlexecutors-prod-dot-sl-test-project-363109"
    ".uc.r.appspot.com"
)
ENDPOINTS = {
    FEW_SHOT_GENERATION: "/api/lite/generate",
    "Linear Executor": "/api/executor/linear",
    "Rag Executor": "/api/executor/rag",
    "COT": "/api/executor/cot"
}
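# ENDPOINTS maps the method names used by the evaluation UI to the API paths
# of the matching backend. For example, "Rag Executor" resolves to
# <CORE_EXECUTORS base URL>/api/executor/rag when the request URL is built in
# call_generate_sql_api().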
params = dict(
execution=False,
lite_model=FEW_SHOT_GENERATION,
access_token=""
)
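# Module-level, mutable evaluation settings. The Streamlit UI is expected to
# populate `access_token` (and `lite_model` for the Lite backend) before
# bq_evaluator() runs; `execution` controls whether the backend also executes
# the generated SQL.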
# llm = VertexAI(temperature=0,
# model_name="gemini-1.5-pro-001",
# max_output_tokens=1024)
# def auto_verify(nl_description, ground_truth, llm_answer):
# """
# This function verifies the accuracy of a generated SQL query based on a
# natural language description and a ground truth query, using
# a Vertex AI LLM.
# Parameters:
# - nl_description (str): The natural language description of the SQL
# query.
# - ground_truth (str): The ground truth SQL query.
# - llm_answer (str): The student's generated SQL query for validation.
# Returns:
# str: "Yes" if the student's answer matches the ground truth and fits
# the NL description correctly,
# "No" otherwise.
# """
# prompt = f'''
# You are an expert at validating SQL queries. Given the natural
# language description and the SQL query corresponding to that
# description, please check if the student's answer is correct.
# There can be different ways to achieve the same result by forming
# the query differently. If the student's SQL query matches the ground
# truth and fits the NL description correctly, then return yes,
# else return no.
# Natural language description: {nl_description}
# Ground truth: {ground_truth}
# student's answer: {llm_answer}
# '''
# return llm(prompt)
def execute_sql_query(query, client, job_config):
"""Execute given SQL query to fetch result"""
try:
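        # The generated SQL may arrive JSON-escaped, so strip escaped newlines
        # and stray backslashes before submitting the query to BigQuery.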
cleaned_query = query.replace("\\n", " ").replace("\n", " ")
cleaned_query = cleaned_query.replace("\\", "")
query_job = client.query(cleaned_query, job_config=job_config)
response = query_job.result().to_dataframe()
except Exception as e: # pylint: disable=broad-except
response = f"{str(e)}"
return response
def dbai_framework(question, bq_project_id, bq_dataset_id, tables_list=None):
    """Call DBAI to get the NL2SQL response."""
    # Avoid a mutable default argument; fall back to an empty table list.
    dbai_nl2sql = DBAI_nl2sql(
        proj_id=bq_project_id,
        dataset_id=bq_dataset_id,
        tables_list=tables_list or []
    )
    return dbai_nl2sql.get_sql(question).generated_sql
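# Usage sketch for dbai_framework() (illustrative values only):
#   sql = dbai_framework("How many orders were placed last month?",
#                        "my-project", "my_dataset")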
def call_generate_sql_api(question, endpoint) -> tuple[str, str]:
"""
Common SQL generation function
"""
    if LITE_API_PART in endpoint:
        api_url = os.getenv(GEN_BY_LITE)
        few_shot_gen = params["lite_model"] == FEW_SHOT_GENERATION
        data = {"question": question,
                "execute_sql": params["execution"],
                "few_shot": few_shot_gen}
    else:
        api_url = os.getenv(GEN_BY_CORE)
        data = {"question": question,
                "execute_sql": params["execution"]}
headers = {"Content-type": "application/json",
"Authorization": f"Bearer {params['access_token']}"}
    # ENDPOINTS values already begin with "/", so avoid a double slash.
    api_endpoint = f"{api_url}{endpoint}"
logger.info(f"Invoking API : {api_endpoint}")
logger.info(f"Provided parameters are : Data = {data}")
api_response = requests.post(api_endpoint,
data=json.dumps(data),
headers=headers,
timeout=None)
exec_result = ""
    try:
        resp = api_response.json()
        logger.info(f"API response : {resp}")
        sql = resp['generated_query']
        exec_result = resp['sql_result']
    except (KeyError, ValueError):
        # The response was not valid JSON or lacked the expected keys.
        sql = "Execution Failed ! Error encountered in the SQL generation API"
logger.info(f"Generated SQL = {sql}")
return sql, exec_result
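# Expected response shape (a sketch based on the keys read above; the actual
# backend payload may contain additional fields):
#   {
#       "generated_query": "SELECT ...",
#       "sql_result": "..."   # meaningful only when execute_sql is True
#   }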
def db_setup(project_id, dataset_id, metadata_path, method):
""" """
token = "Bearer "
body = {
"proj_name": project_id,
"bq_dataset": dataset_id,
}
headers = {"Content-type": "application/json",
"Authorization": token}
if method == FEW_SHOT_GENERATION:
url = os.getenv(GEN_BY_LITE)
else:
url = os.getenv(GEN_BY_CORE)
if metadata_path:
        with open(metadata_path, "r", encoding="utf-8") as f:
string_data = f.read()
files = {"file": (metadata_path.split("/")[-1], string_data)}
body["metadata_file"] = metadata_path.split("/")[-1]
_ = requests.post(
url=url+"/projconfig",
data=json.dumps(body),
headers=headers,
timeout=None)
if metadata_path:
_ = requests.post(
url=url+"/uploadfile",
headers={"Authorization": token},
files=files,
timeout=None
)
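# db_setup() makes two calls: /projconfig registers the BigQuery project and
# dataset with the chosen backend, and /uploadfile (only when a metadata file
# is supplied) uploads that metadata file to the same backend.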
def bq_evaluator(
bq_project_id,
bq_dataset_id,
ground_truth_path,
method,
metadata_path=None,
pb=None,
render_result=True,
):
""" """
    ts = time.strftime("%y%m%d%H%M")
    # Make sure the per-run output directory exists before appending results.
    os.makedirs('evaluation/eval_output', exist_ok=True)
    client = bigquery.Client(project=bq_project_id)
job_config = bigquery.QueryJobConfig(
maximum_bytes_billed=100000000,
default_dataset=f'{bq_project_id}.{bq_dataset_id}'
)
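    # maximum_bytes_billed caps each query's scan cost, while default_dataset
    # lets both the ground-truth and the generated SQL reference tables
    # without a fully qualified dataset prefix.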
if method != "DBAI":
db_setup(bq_project_id, bq_dataset_id, metadata_path, method)
df = pd.read_csv(ground_truth_path)
all_results = []
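    # The ground-truth CSV is expected to contain exactly two columns,
    # (question, ground_truth_sql), so each row unpacks cleanly below.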
for idx, (question, ground_truth_sql) in df.iterrows():
match method:
case "DBAI":
generated_query = dbai_framework(
question, bq_project_id, bq_dataset_id)
case _:
generated_query, _ = call_generate_sql_api(
question=question, endpoint=ENDPOINTS[method])
generated_query_result = execute_sql_query(generated_query,
client,
job_config
)
actual_query_result = execute_sql_query(ground_truth_sql,
client,
job_config
)
        # llm_rating = auto_verify(question, ground_truth_sql, generated_query)
        # auto_verify (LLM-based query rating) is disabled above, so
        # query_eval is a constant placeholder for now.
        llm_rating = 'No'
        result_eval = 0
        try:
            if generated_query_result.equals(actual_query_result):
                result_eval = 1
        except AttributeError:
            # One of the results is an error string rather than a DataFrame.
            result_eval = 0
out = [(
question, ground_truth_sql, actual_query_result, generated_query,
generated_query_result, llm_rating, result_eval
)]
all_results.extend(out)
out_df = pd.DataFrame(
out,
columns=[
'question',
'ground_truth_sql',
'actual_query_result',
'generated_query',
'generated_query_result',
'query_eval',
'result_eval'
])
        out_df.to_csv(f'evaluation/eval_output/eval_result_{ts}.csv',
                      index=False,
                      mode='a',
                      # Append one row per question; write the header once.
                      header=(idx == 0)
                      )
if pb:
pb.progress(
(idx+1)/len(df),
text=f"Evaluation in progress. Please wait...{idx+1}/{len(df)}"
)
        if render_result:
            if idx == 0:
                rendered_df = st.dataframe(out_df)
            else:
                rendered_df.add_rows(out_df)
    if pb:
        pb.empty()
all_results_df = pd.DataFrame(
all_results,
columns=[
'question', 'ground_truth_sql', 'actual_query_result',
'generated_query',
'generated_query_result',
'query_eval',
'result_eval'
])
    # Accuracy over the whole run, not just the last question.
    accuracy = all_results_df.result_eval.sum() / len(df)
    print(f'Accuracy: {accuracy}')
return {
"accuracy": accuracy,
"output": all_results_df
}
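# Usage sketch from a Streamlit page (project, dataset and paths are
# illustrative):
#   pb = st.progress(0, text="Evaluation in progress. Please wait...")
#   result = bq_evaluator("my-project", "my_dataset",
#                         "evaluation/ground_truth.csv", "Linear Executor",
#                         metadata_path=None, pb=pb)
#   st.write(f"Accuracy: {result['accuracy']}")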
if __name__ == '__main__':
BQ_PROJECT_ID = 'proj-kous'
BQ_DATASET_ID = 'nl2sql_fiserv'
GROUND_TRUTH_PATH = 'evaluation/fiserv_ground_truth.csv'
METADATA_PATH = "./nl2sql_src/cache_metadata/fiserv.json"
METHOD = "lite"
bq_evaluator(BQ_PROJECT_ID,
BQ_DATASET_ID,
GROUND_TRUTH_PATH,
METHOD,
METADATA_PATH
)