evalbench/databases/bigquery.py (220 lines of code) (raw):

from google.cloud import bigquery
import logging
import re
from .db import DB
from .util import with_cache_execute, DatabaseSchema
from util.rate_limit import rate_limit, ResourceExhaustedError
from typing import List, Optional, Tuple, Any, Dict
import json
import sqlparse
from google.cloud.bigquery import QueryJobConfig, ConnectionProperty
from util.gcp import get_gcp_project
from google.api_core.exceptions import GoogleAPICallError


class BQDB(DB):
    """BigQuery implementation of the ``DB`` interface.

    Queries are run through a ``google.cloud.bigquery.Client``. The literal
    placeholder ``{{dataset}}`` inside a query is replaced by ``self.db_name``
    before execution.

    NOTE(review): ``self.db_name``, ``self.cache_client``, ``self.tmp_dbs``,
    ``self.execs_per_minute``, ``self.semaphore`` and ``self.max_attempts`` are
    read here but never assigned in this class; presumably the ``DB`` base
    class sets them from ``db_config`` -- confirm.
    """

    #####################################################
    #####################################################
    # Database Connection Setup Logic
    #####################################################
    #####################################################

    def __init__(self, db_config):
        """Create the BigQuery client.

        Args:
            db_config: Mapping-like configuration. Only ``location`` is read
                here (default ``"US"``); the rest is handed to ``DB.__init__``.
        """
        super().__init__(db_config)
        self.project_id = get_gcp_project("")
        self.location = db_config.get("location", "US")
        self.client = bigquery.Client(project=self.project_id)
        self.tmp_users = []

    #####################################################
    #####################################################
    # Database Specific Execution Logic and Handling
    #####################################################
    #####################################################

    def _execute_queries(
        self, query: str, job_config: Optional[bigquery.QueryJobConfig] = None
    ) -> List:
        """Run every statement in ``query`` and collect all rows as dicts.

        The text is split into individual statements with ``sqlparse.split``;
        each non-empty statement is submitted as its own job using the same
        ``job_config``.

        Returns:
            One flat list holding ``dict(row)`` for every row of every
            statement, in execution order.
        """
        result: List = []
        for sub_query in sqlparse.split(query):
            if not sub_query:
                continue
            rows = self.client.query(sub_query, job_config=job_config).result()
            if rows:
                result.extend(dict(row) for row in rows)
        return result

    def batch_execute(self, commands: list[str]):
        """Execute each command in order via :meth:`execute`.

        NOTE(review): the ``(result, eval_result, error)`` tuple returned by
        ``execute`` is discarded, so a failing command does not raise here.
        ``insert_data`` wraps this call in ``except RuntimeError``, which
        suggests failures were meant to be raised -- confirm the intended
        contract before changing it.
        """
        for command in commands:
            self.execute(command)

    def execute(
        self,
        query: str,
        eval_query: Optional[str] = None,
        use_cache=False,
        rollback=False,
    ) -> Tuple[Any, Any, Any]:
        """Execute ``query`` (and optionally ``eval_query``).

        The cache is used only when ``use_cache`` is true, a cache client is
        configured and no ``eval_query`` is given; otherwise the query goes
        straight to :meth:`_execute`.

        Returns:
            ``(None, None, None)`` for a blank query, otherwise whatever
            ``_execute`` / ``with_cache_execute`` returns.
        """
        if query.strip() == "":
            return None, None, None
        if not use_cache or not self.cache_client or eval_query:
            return self._execute(query, eval_query, rollback)
        return with_cache_execute(
            query,
            f"{self.project_id}.{self.db_name}",
            self._execute,
            self.cache_client,
        )

    def _execute(
        self, query: str, eval_query: Optional[str] = None, rollback=False
    ) -> Tuple[Any, Any, Any]:
        """Run the query through ``rate_limit`` and return its outcome.

        Args:
            query: SQL text; may contain several statements and the
                ``{{dataset}}`` placeholder.
            eval_query: Optional follow-up SQL executed after ``query``.
            rollback: When true, both queries run inside a BigQuery session
                transaction that is rolled back and whose session is aborted
                afterwards.

        Returns:
            ``(result, eval_result, error)`` where ``error`` is ``None`` or
            the stringified exception. ``(None, None, None)`` is returned when
            ``rate_limit`` finally raises ``ResourceExhaustedError``.
        """

        def _run_execute(query: str, eval_query: Optional[str] = None, rollback=False):
            result: List = []
            eval_result: List = []
            error = None
            query_replaced = query.replace("{{dataset}}", self.db_name)
            eval_query_replaced = None
            if eval_query is not None:
                eval_query_replaced = eval_query.replace("{{dataset}}", self.db_name)

            try:
                if rollback:
                    # Explicit sentinel instead of probing locals(): the
                    # session is aborted only if it was actually created.
                    session_id = None
                    conn_props: List = []
                    try:
                        init_job = self.client.query(
                            "SELECT 1;",
                            job_config=QueryJobConfig(create_session=True),
                        )
                        init_job.result()
                        session_id = init_job.session_info.session_id
                        conn_props = [
                            ConnectionProperty(key="session_id", value=session_id)
                        ]
                        self.client.query(
                            "BEGIN TRANSACTION;",
                            job_config=QueryJobConfig(connection_properties=conn_props),
                        ).result()
                        result = self._execute_queries(
                            query_replaced,
                            job_config=QueryJobConfig(connection_properties=conn_props),
                        )
                        if eval_query:
                            eval_result = self._execute_queries(
                                eval_query_replaced,
                                job_config=QueryJobConfig(
                                    connection_properties=conn_props
                                ),
                            )
                        self.client.query(
                            "ROLLBACK TRANSACTION;",
                            job_config=QueryJobConfig(connection_properties=conn_props),
                        ).result()
                    except Exception as e:
                        # NOTE(review): errors raised inside the transaction
                        # are captured here, so they never reach the
                        # resource-exhaustion check below and are not retried
                        # -- confirm this is intended.
                        error = str(e)
                        print(f"Error: {error}")
                    finally:
                        if session_id is not None:
                            self.client.query(
                                "CALL BQ.ABORT_SESSION();",
                                job_config=QueryJobConfig(
                                    connection_properties=conn_props
                                ),
                            ).result()
                else:
                    result = self._execute_queries(query_replaced)
                    if eval_query:
                        eval_result = self._execute_queries(eval_query_replaced)
            except (GoogleAPICallError, Exception) as e:
                error = str(e)
                # BigQuery capitalises these messages ("Resources exceeded",
                # "Quota exceeded"), so the match must ignore case.
                lowered = error.lower()
                if "resources exceeded" in lowered:
                    raise ResourceExhaustedError(
                        f"BigQuery resources exhausted: {e}"
                    ) from e
                elif "quota exceeded" in lowered:
                    raise ResourceExhaustedError(
                        f"BigQuery quota exceeded: {e}"
                    ) from e
                else:
                    print(error)
            return result, eval_result, error

        try:
            return rate_limit(
                (query, eval_query, rollback),
                _run_execute,
                self.execs_per_minute,
                self.semaphore,
                self.max_attempts,
            )
        except ResourceExhaustedError:
            logging.info(
                "Resource Exhausted on BigQuery. Giving up execution. Try reducing execs_per_minute."
            )
            return None, None, None

    def get_metadata(self) -> dict:
        """Return ``{table_id: [{"name": ..., "type": ...}, ...]}`` for the dataset.

        Any exception is printed and whatever was collected up to that point
        is returned (possibly an empty dict).
        """
        metadata = {}
        try:
            for table in self.client.list_tables(self.db_name):
                schema = self.client.get_table(table.reference).schema
                metadata[table.table_id] = [
                    {"name": f.name, "type": f.field_type} for f in schema
                ]
        except Exception as e:
            print(f"Error while fetching metadata for dataset '{self.db_name}': {e}")
        return metadata

    #####################################################
    #####################################################
    # Setup / Teardown of temporary databases
    #####################################################
    #####################################################

    def generate_ddl(self, schema: DatabaseSchema) -> List[str]:
        """Build one ``CREATE TABLE`` statement per table in ``schema``.

        Table names are fully qualified as ``project.dataset.table``. On any
        exception the error is printed and the statements built so far are
        returned.
        """
        ddl_statements = []
        try:
            for table in schema.tables:
                columns = ", ".join(f"{col.name} {col.type}" for col in table.columns)
                ddl_statements.append(
                    f"CREATE TABLE `{self.project_id}.{self.db_name}.{table.name}` ({columns})"
                )
        except Exception as e:
            print(f"Error generating DDL statements: {e}")
        return ddl_statements

    def create_tmp_database(self, database_name: str):
        """Create dataset ``database_name`` (no error if it exists) and track it."""
        dataset_ref = bigquery.Dataset(f"{self.project_id}.{database_name}")
        dataset_ref.location = self.location
        self.client.create_dataset(dataset_ref, exists_ok=True)
        self.tmp_dbs.append(database_name)

    def drop_tmp_database(self, database_name: str):
        """Delete dataset ``database_name`` with its contents; failures are logged."""
        try:
            self.client.delete_dataset(
                dataset=database_name,
                delete_contents=True,
                not_found_ok=True,
            )
            if database_name in self.tmp_dbs:
                self.tmp_dbs.remove(database_name)
        except Exception as e:
            logging.warning(f"Failed to drop dataset {database_name}: {e}")

    def drop_all_tables(self):
        """Delete every table in the current dataset.

        Raises:
            RuntimeError: If listing or deleting fails; the original exception
                is attached as the cause.
        """
        try:
            for table in list(self.client.list_tables(self.db_name)):
                full_table_id = f"{self.project_id}.{self.db_name}.{table.table_id}"
                self.client.delete_table(full_table_id)
        except Exception as e:
            raise RuntimeError(
                f"Failed to drop tables in dataset {self.db_name}: {e}"
            ) from e

    def _is_float(self, value) -> bool:
        """Return True when ``float(value)`` succeeds.

        ``TypeError`` is caught as well so that ``None`` and other non-numeric
        objects yield ``False`` instead of raising.
        """
        try:
            float(value)
        except (TypeError, ValueError):
            return False
        return True

    def _get_column_name_to_type_mapping(
        self, sql_statements: Optional[List[str]]
    ) -> Dict[str, Dict[str, str]]:
        """Parse ``CREATE TABLE`` statements into ``{table: {column: type}}``.

        Only statements of the form ``CREATE TABLE `{{dataset}}.<name>` (`` with
        one column definition per line are recognised; anything else is
        skipped. ``None`` is treated as an empty list.
        """
        schema_mapping: Dict[str, Dict[str, str]] = {}
        for statement in sql_statements or []:
            table_match = re.search(r'CREATE TABLE\s+`{{dataset}}\.(\w+)`', statement)
            if not table_match:
                continue
            table_name = table_match.group(1)

            column_section_match = re.search(r'\(\n(.*?)\n\)', statement, re.DOTALL)
            if not column_section_match:
                continue

            column_type_mapping = {}
            for col in column_section_match.group(1).split(",\n"):
                stripped = col.strip()
                if stripped.startswith("PRIMARY KEY"):
                    continue
                col_parts = stripped.split()
                if len(col_parts) >= 2:
                    # First token is the column name, second its type.
                    column_type_mapping[col_parts[0]] = col_parts[1]
            schema_mapping[table_name] = column_type_mapping
        return schema_mapping

    @staticmethod
    def _escape_string_literal(value: str) -> str:
        """Convert doubled single quotes (``''``) to ``\\'``.

        For a value wrapped in single quotes only the interior is rewritten,
        so the delimiters are never mistaken for an escaped quote (the empty
        literal ``''`` stays ``''``).
        """
        if len(value) >= 2 and value[0] == "'" and value[-1] == "'":
            return "'" + value[1:-1].replace("''", "\\'") + "'"
        return value.replace("''", "\\'")

    def insert_data(self, data: dict[str, List[str]], setup: Optional[List[str]] = None):
        """Insert rows into the current dataset, one ``INSERT`` per row.

        Args:
            data: Maps a table name to its rows; each row is a sequence of
                values already rendered as SQL text (presumably with string
                values wrapped in single quotes -- confirm against callers).
            setup: ``CREATE TABLE`` statements used to look up column types by
                position. When missing, or when a table or column position is
                not found in them, the value gets type-agnostic formatting.

        Raises:
            RuntimeError: If ``batch_execute`` raises ``RuntimeError``.
        """
        if not data:
            return
        schema_mapping = self._get_column_name_to_type_mapping(setup)
        insertion_statements = []
        for table_name in data:
            # dicts preserve insertion order, so position i is column i.
            column_types = [
                col_type.upper()
                for col_type in schema_mapping.get(table_name, {}).values()
            ]
            for row in data[table_name]:
                formatted_values = []
                for index, value in enumerate(row):
                    col_type = column_types[index] if index < len(column_types) else ""
                    if value is None:
                        formatted_values.append("NULL")
                    elif col_type in ("BOOL", "BOOLEAN"):
                        if value == "'1'":
                            formatted_values.append("TRUE")
                        elif value == "'0'":
                            formatted_values.append("FALSE")
                        else:
                            formatted_values.append(f"{value}")
                    elif self._is_float(value):
                        formatted_values.append(f"{value}")
                    elif col_type == 'JSON':
                        formatted_values.append(f"PARSE_JSON({value})")
                    else:
                        formatted_values.append(self._escape_string_literal(value))
                inline_columns = ", ".join(formatted_values)
                insertion_statements.append(
                    f"INSERT INTO `{self.project_id}.{self.db_name}.{table_name}` VALUES ({inline_columns});"
                )
        try:
            self.batch_execute(insertion_statements)
        except RuntimeError as error:
            raise RuntimeError(
                f"Could not insert data into database: {error}"
            ) from error

    #####################################################
    #####################################################
    # Database User Management
    #####################################################
    #####################################################

    def close_connections(self):
        """No-op: this class keeps no connections that need closing."""
        pass

    def create_tmp_users(self, dql_user: str, dml_user: str, tmp_password: str):
        """No-op: temporary users are not implemented for BigQuery."""
        pass

    def delete_tmp_user(self, username: str):
        """No-op: temporary users are not implemented for BigQuery."""
        pass