in UI/evaluation/eval.py [0:0]
def bq_evaluator(
    bq_project_id,
    bq_dataset_id,
    ground_truth_path,
    method,
    metadata_path=None,
    pb=None,
    render_result=True,
):
    """Evaluate a text-to-SQL method against a ground-truth CSV on BigQuery.

    For every ``(question, ground_truth_sql)`` row of ``ground_truth_path``,
    a SQL query is generated with ``method``.  Both the generated query and
    the ground-truth query are then executed on BigQuery, and their results
    are compared with ``.equals``.  Each scored row is appended to
    ``evaluation/eval_output/eval_result_<yymmddHHMM>.csv`` as soon as it is
    produced, and can optionally be streamed to the UI.

    Args:
        bq_project_id: GCP project used for the BigQuery client and as the
            project part of the default dataset.
        bq_dataset_id: dataset set as the default dataset for all queries.
        ground_truth_path: CSV whose rows unpack into exactly two values,
            ``(question, ground_truth_sql)``.
        method: ``"DBAI"`` generates SQL via ``dbai_framework``.  Any other
            value first runs ``db_setup`` and is then used as a key into
            ``ENDPOINTS`` for ``call_generate_sql_api``.
        metadata_path: forwarded to ``db_setup`` (not used for ``"DBAI"``).
        pb: optional progress indicator exposing ``progress(value, text=...)``
            and ``empty()`` (presumably a Streamlit progress bar; confirm
            against callers).
        render_result: if true, each scored row is rendered with
            ``st.dataframe`` / ``add_rows``.

    Returns:
        dict with two entries:

        - ``"accuracy"``: fraction of ground-truth rows whose generated
          result equals the ground-truth result; 0.0 if the file has no rows.
        - ``"output"``: DataFrame with one row per question.
    """
    columns = [
        'question',
        'ground_truth_sql',
        'actual_query_result',
        'generated_query',
        'generated_query_result',
        'query_eval',
        'result_eval',
    ]
    ts = time.strftime("%y%m%d%H%M")
    output_path = f'evaluation/eval_output/eval_result_{ts}.csv'

    client = bigquery.Client(project=bq_project_id)
    job_config = bigquery.QueryJobConfig(
        maximum_bytes_billed=100000000,
        default_dataset=f'{bq_project_id}.{bq_dataset_id}'
    )
    if method != "DBAI":
        db_setup(bq_project_id, bq_dataset_id, metadata_path, method)

    df = pd.read_csv(ground_truth_path)
    total = len(df)
    all_results = []
    rendered_df = None

    # enumerate gives a 0-based position independent of the CSV's index
    # labels; it drives progress reporting and header/first-render logic.
    for position, (question, ground_truth_sql) in enumerate(
        df.itertuples(index=False, name=None)
    ):
        if method == "DBAI":
            generated_query = dbai_framework(
                question, bq_project_id, bq_dataset_id)
        else:
            generated_query, _ = call_generate_sql_api(
                question=question, endpoint=ENDPOINTS[method])

        generated_query_result = execute_sql_query(
            generated_query, client, job_config)
        actual_query_result = execute_sql_query(
            ground_truth_sql, client, job_config)

        # LLM-based verification (auto_verify) is currently disabled, so
        # every row reports query_eval == 'No'.
        llm_rating = 'No'

        try:
            result_eval = (
                1 if generated_query_result.equals(actual_query_result) else 0
            )
        except (RuntimeError, AttributeError):
            # NOTE(review): AttributeError covers a result without .equals
            # (e.g. if execute_sql_query returns an error value instead of a
            # DataFrame); such a row is scored as a mismatch rather than
            # aborting the whole evaluation. Confirm against execute_sql_query.
            result_eval = 0

        row = (
            question, ground_truth_sql, actual_query_result, generated_query,
            generated_query_result, llm_rating, result_eval,
        )
        all_results.append(row)
        out_df = pd.DataFrame([row], columns=columns)

        # Appending row by row keeps partial results on disk if a later
        # query fails. The header is written only with the first row;
        # otherwise every appended row would be preceded by a header line.
        out_df.to_csv(
            output_path,
            index=False,
            mode='a',
            header=(position == 0),
        )

        if pb:
            pb.progress(
                (position + 1) / total,
                text=f"Evaluation in progress. Please wait...{position+1}/{total}"
            )
        if render_result:
            if rendered_df is None:
                rendered_df = st.dataframe(out_df)
            else:
                rendered_df.add_rows(out_df)

    if pb:
        pb.empty()

    all_results_df = pd.DataFrame(all_results, columns=columns)
    # Accuracy is computed over every scored row, not only the last one.
    # An empty ground-truth file yields 0.0 instead of dividing by zero.
    accuracy = all_results_df['result_eval'].sum() / total if total else 0.0
    print(f'Accuracy: {accuracy}')
    return {
        "accuracy": accuracy,
        "output": all_results_df
    }