evalbench/work/sqlexecwork.py (99 lines of code) (raw):
"""Work is the base class for all work items."""
from typing import Any
from databases import DB
from work import Work
from util.sanitizer import sanitize_sql
from queue import Queue
import sqlparse
class SQLExecWork(Work):
    """Work item that executes generated and golden SQL against a database.

    The item reads the generated SQL and its metadata from ``eval_result``,
    runs the sanitized generated query and the golden query through ``db``,
    and writes the results, evaluation results and errors back into
    ``eval_result``.
    """

    def __init__(
        self,
        db: DB,
        experiment_config: dict,
        eval_result: dict,
        db_queue: Queue,
    ):
        """Store the collaborators used by ``run``.

        Args:
            db: Database handle the queries are executed on.
            experiment_config: Experiment settings; ``"prompt_generator"`` is
                read when sanitizing the generated SQL.
            eval_result: Record for one evaluation item; it is read and
                updated in place.
            db_queue: Queue that ``db`` is put onto when ``run`` finishes.
        """
        self.db = db
        self.experiment_config = experiment_config
        self.eval_result = eval_result
        self.db_queue = db_queue

    def run(self, work_config: Any = None) -> dict:
        """Execute the generated and golden SQL and record the outcome.

        Nothing is executed when ``eval_result["sql_generator_error"]`` is
        set or ``eval_result["generated_sql"]`` is empty; the result keys are
        then all recorded as ``None``. The golden query is executed even when
        sanitizing leaves no generated SQL to run.

        Args:
            work_config: Unused by this work item.

        Returns:
            ``self.eval_result`` with ``generated_result``, ``eval_results``,
            ``generated_error``, ``golden_result``, ``golden_eval_results``
            and ``golden_error`` filled in.
        """
        generated_result = None
        generated_eval_result = None
        generated_error = None
        golden_result = None
        golden_eval_result = None
        golden_error = None
        try:
            if (
                self.eval_result["sql_generator_error"] is None
                and self.eval_result["generated_sql"]
            ):
                query_type = self.eval_result["query_type"]
                eval_query = self._get_eval_query()
                sanitized_generated_sql = self._sanitize_sql()
                golden_sql = self._get_golden_sql()
                if sanitized_generated_sql:
                    generated_result, generated_eval_result, generated_error = (
                        self._evaluate_execution_results(
                            sanitized_generated_sql,
                            eval_query,
                            query_type,
                            is_golden=False,
                        )
                    )
                golden_result, golden_eval_result, golden_error = (
                    self._evaluate_execution_results(
                        golden_sql, eval_query, query_type, is_golden=True
                    )
                )
            self.eval_result["generated_result"] = generated_result
            self.eval_result["eval_results"] = generated_eval_result
            self.eval_result["generated_error"] = generated_error
            self.eval_result["golden_result"] = golden_result
            self.eval_result["golden_eval_results"] = golden_eval_result
            self.eval_result["golden_error"] = golden_error
        finally:
            # Always hand the database back, even when execution raised;
            # otherwise the handle never returns to the queue.
            self.db_queue.put(self.db)
        return self.eval_result

    def _evaluate_execution_results(
        self, query, eval_query, query_type, is_golden=False
    ):
        """Execute ``query`` according to ``query_type``.

        NOTE: ``setup_sql`` / ``cleanup_sql`` from ``eval_result`` are not run
        here.

        Args:
            query: SQL text to execute.
            eval_query: Query passed along with DML statements; its result is
                returned as the evaluation result.
            query_type: ``"dql"``, ``"dml"`` or ``"ddl"``. Any other value
                executes nothing.
            is_golden: Currently unused; kept for call-site compatibility.

        Returns:
            A ``(result, eval_result, error)`` tuple. All three are ``None``
            for an unrecognized ``query_type``.
        """
        result = None
        eval_result = None
        error = None
        if query_type == "dql":
            # Only the first statement of a DQL query is executed.
            statements = sqlparse.split(query)
            if not statements:
                # An empty query (e.g. a missing golden SQL) has no first
                # statement; report it instead of raising IndexError.
                return None, None, "No SQL statement found to execute."
            result, _, error = self.db.execute(
                statements[0], use_cache=True, rollback=True
            )
        elif query_type == "dml":
            result, eval_result, error = self.db.execute(
                query, eval_query, use_cache=False, rollback=True
            )
        elif query_type == "ddl":
            # DDL runs against a freshly re-set-up database.
            try:
                self.db.resetup_database(force=True)
            except Exception as setup_error:
                return (
                    None,
                    None,
                    f"Was not able to run DDL due to setup_error {setup_error}",
                )
            result, _, error = self.db.execute(query, use_cache=False)
            # For DDL the database metadata is the evaluation result.
            eval_result = self.db.get_metadata()
        return result, eval_result, error

    def _sanitize_sql(self):
        """Store and return the SQL to execute for the generated query.

        With the ``"NOOPGenerator"`` prompt generator the generated SQL is
        used unchanged; otherwise it is passed through ``sanitize_sql``. The
        value is also saved under ``eval_result["sanitized_sql"]``.
        """
        generated_sql = self.eval_result["generated_sql"]
        if self.experiment_config["prompt_generator"] == "NOOPGenerator":
            self.eval_result["sanitized_sql"] = generated_sql
        else:
            self.eval_result["sanitized_sql"] = sanitize_sql(generated_sql)
        return self.eval_result["sanitized_sql"]

    def _get_golden_sql(self):
        """Return the golden SQL as a single string.

        ``eval_result["golden_sql"]`` may be a string or a list; for a
        non-empty list the first entry is used. Anything else yields ``""``.
        """
        golden_sql = self.eval_result["golden_sql"]
        if isinstance(golden_sql, str):
            return golden_sql
        if isinstance(golden_sql, list) and golden_sql:
            return golden_sql[0]
        return ""

    def _get_eval_query(self):
        """Return the evaluation query, or ``None`` when there is none.

        ``eval_result["eval_query"]`` may be a string or a sequence; for a
        non-empty sequence the first entry is used. A plain string is returned
        whole rather than being indexed to its first character.
        """
        eval_query = self.eval_result["eval_query"]
        if isinstance(eval_query, str):
            return eval_query or None
        if eval_query:
            return eval_query[0]
        return None