evalbench/evaluator/db_manager.py (94 lines of code) (raw):

from queue import Queue
from copy import deepcopy
from databases import DB, get_database
from util.config import load_db_data_from_csvs, load_setup_scripts
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Callable, Optional


def build_db_queue(
    core_db: DB, db_name, db_config, setup_config, query_type: str, num_dbs: int
):
    """Build a queue of database handles suited to ``query_type``.

    Args:
        core_db: Handle to the primary database; used to re-run setup and to
            obtain restricted users or temporary databases.
        db_name: Name of the database under evaluation.
        db_config: Connection config mapping; it is deep-copied, never mutated.
        setup_config: Mapping providing ``"setup_directory"``, or a falsy value
            when no setup should be (re)run.
        query_type: ``"dql"``, ``"dml"`` or ``"ddl"``.
        num_dbs: Number of entries to place in the returned queue.

    Returns:
        A ``Queue`` of ``DB`` handles. For an unrecognised ``query_type`` an
        empty queue is returned.
    """
    preparers = {
        "dql": _prepare_db_queue_for_dql,
        "dml": _prepare_db_queue_for_dml,
        "ddl": _prepare_db_queue_for_ddl,
    }
    prepare = preparers.get(query_type)
    if prepare is None:
        # NOTE(review): a consumer doing a blocking get() on this empty queue
        # will wait forever; consider raising ValueError instead.
        return Queue[DB]()
    return prepare(core_db, db_name, db_config, setup_config, num_dbs)


def _prepare_shared_db_queue(
    core_db: DB, db_name, db_config, setup_config, num_dbs, get_user: Callable
):
    """Return a queue holding ``num_dbs`` references to one restricted-user DB.

    Shared implementation of the DQL and DML preparers, which differ only in
    which restricted user they connect as.

    Args:
        core_db: Primary database; re-setup (if requested) and the temporary
            user's password come from here.
        db_name: Name of the database to connect to.
        db_config: Base connection config; a deep copy is modified.
        setup_config: Setup mapping, or falsy to skip re-setup.
        num_dbs: How many times the single handle is enqueued.
        get_user: Zero-argument callable returning the user name to connect as.

    Raises:
        FileNotFoundError: If setup files cannot be loaded.
    """
    restricted_config = deepcopy(db_config)
    if setup_config:
        setup_scripts, data = _get_setup_values(
            setup_config, db_name, db_config.get("db_type")
        )
        core_db.set_setup_instructions(setup_scripts, data)
        # NOTE(review): meaning of the two positional flags is defined by
        # DB.resetup_database; kept exactly as the original call.
        core_db.resetup_database(False, True)
    # The user is looked up only after any re-setup has completed.
    restricted_config["user_name"] = get_user()
    restricted_config["password"] = core_db.get_tmp_user_password()
    shared_db = get_database(restricted_config, db_name)

    db_queue = Queue[DB]()
    for _ in range(num_dbs):
        # Every slot holds the same handle object.
        db_queue.put(shared_db)
    return db_queue


def _prepare_db_queue_for_dql(core_db: DB, db_name, db_config, setup_config, num_dbs):
    """For DQL, use the same single DB with a user that has only DQL access."""
    return _prepare_shared_db_queue(
        core_db, db_name, db_config, setup_config, num_dbs, core_db.get_dql_user
    )


def _prepare_db_queue_for_dml(core_db: DB, db_name, db_config, setup_config, num_dbs):
    """For DML, use the same single DB with a user that has only DQL / DML access."""
    return _prepare_shared_db_queue(
        core_db, db_name, db_config, setup_config, num_dbs, core_db.get_dml_user
    )


def _prepare_db_queue_for_ddl(core_db: DB, db_name, db_config, setup_config, num_dbs):
    """For DDL, give each queue slot its own temporary database.

    The core DB is re-set up with the schema scripts, then ``num_dbs``
    temporary databases are created via ``core_db.create_tmp_databases`` and
    a handle for each is opened concurrently, with the same setup scripts
    attached.

    Raises:
        ValueError: If ``setup_config`` is falsy; DDL requires setup scripts.
        FileNotFoundError: If setup files cannot be loaded.
    """
    if not setup_config:
        raise ValueError("No Setup Config was provided for DDL")

    # Load the setup files once and reuse the scripts for the core DB and
    # every temporary DB; the CSV data is not used for DDL.
    setup_scripts, _ = _get_setup_values(
        setup_config, db_name, db_config.get("db_type")
    )
    core_db.set_setup_instructions(setup_scripts, None)
    core_db.resetup_database(False, False)

    tmp_dbs = core_db.create_tmp_databases(num_dbs)
    create_ddl_tmp_db_p = partial(
        _create_ddl_tmp_db, db_config=db_config, setup_scripts=setup_scripts
    )
    db_queue = Queue[DB]()
    with ThreadPoolExecutor() as executor:
        # executor.map yields results in input order and re-raises any worker
        # exception here.
        for tmp_db in executor.map(create_ddl_tmp_db_p, tmp_dbs):
            db_queue.put(tmp_db)
    return db_queue


def _create_ddl_tmp_db(tmp_db, db_config, setup_scripts):
    """Open a handle to one temporary database and attach its setup scripts.

    Args:
        tmp_db: One element produced by ``core_db.create_tmp_databases``;
            passed to ``get_database`` in the database-name position
            (presumably the temp DB's name — confirm against DB).
        db_config: Base connection config; a deep copy is flagged
            ``is_tmp_db=True``.
        setup_scripts: Scripts attached via ``set_setup_instructions``.

    Returns:
        The ``DB`` handle for the temporary database.
    """
    tmp_ddl_db_config = deepcopy(db_config)
    tmp_ddl_db_config["is_tmp_db"] = True
    tmp_database = get_database(tmp_ddl_db_config, tmp_db)
    tmp_database.set_setup_instructions(setup_scripts, None)
    return tmp_database


def _get_setup_values(setup_config, db_name: str, db_type: str):
    """Load setup scripts and CSV seed data for ``db_name`` on ``db_type``.

    Scripts are read from ``<setup_directory>/<db_name>/<db_type>`` and data
    from ``<setup_directory>/<db_name>/data``.

    Returns:
        A ``(setup_scripts, data)`` tuple.

    Raises:
        FileNotFoundError: If anything goes wrong while building the paths or
            loading; the underlying exception is chained as ``__cause__``.
    """
    try:
        db_setup_dir = setup_config["setup_directory"] + "/" + db_name
        setup_scripts = load_setup_scripts(db_setup_dir + "/" + db_type)
        data = load_db_data_from_csvs(db_setup_dir + "/data")
    except Exception as e:
        raise FileNotFoundError(
            f"Could not find setup files for database {db_name} on {db_type} due to: {e}"
        ) from e
    return setup_scripts, data