dbconnectors/BQConnector.py (241 lines of code) (raw):

""" BigQuery Connector Class """ from google.cloud import bigquery from google.cloud import bigquery_connection_v1 as bq_connection from dbconnectors import DBConnector from abc import ABC from datetime import datetime import google.auth import pandas as pd from google.cloud.exceptions import NotFound def get_auth_user(): credentials, project_id = google.auth.default() if hasattr(credentials, 'service_account_email'): return credentials.service_account_email else: return "Not Determined" def bq_specific_data_types(): return ''' BigQuery offers a wide variety of datatypes to store different types of data effectively. Here's a breakdown of the available categories: Numeric Types - INTEGER (INT64): Stores whole numbers within the range of -9,223,372,036,854,775,808 to 9,223,372,036,854,775,807. Ideal for non-fractional values. FLOAT (FLOAT64): Stores approximate floating-point numbers with a range of -1.7E+308 to 1.7E+308. Suitable for decimals with a degree of imprecision. NUMERIC: Stores exact fixed-precision decimal numbers, with up to 38 digits of precision and 9 digits to the right of the decimal point. Useful for precise financial and accounting calculations. BIGNUMERIC: Similar to NUMERIC but with even larger scale and precision. Designed for extreme precision in calculations. Character Types - STRING: Stores variable-length Unicode character sequences. Enclosed using single, double, or triple quotes. Boolean Type - BOOLEAN: Stores logical values of TRUE or FALSE (case-insensitive). Date and Time Types - DATE: Stores dates without associated time information. TIME: Stores time information independent of a specific date. DATETIME: Stores both date and time information (without timezone information). TIMESTAMP: Stores an exact moment in time with microsecond precision, including a timezone component for global accuracy. Other Types BYTES: Stores variable-length binary data. Distinguished from strings by using 'B' or 'b' prefix in values. GEOGRAPHY: Stores points, lines, and polygons representing locations on the Earth's surface. ARRAY: Stores an ordered collection of zero or more elements of the same (non-ARRAY) data type. STRUCT: Stores an ordered collection of fields, each with its own name and data type (can be nested). This list covers the most common datatypes in BigQuery. ''' class BQConnector(DBConnector, ABC): """ A connector class for interacting with BigQuery databases. This class provides methods for connecting to BigQuery, executing queries, retrieving results as DataFrames, logging interactions, and managing embeddings. Attributes: project_id (str): The Google Cloud project ID where the BigQuery dataset resides. region (str): The region where the BigQuery dataset is located. dataset_name (str): The name of the BigQuery dataset to interact with. opendataqna_dataset (str): Name of the dataset to use for OpenDataQnA functionalities. audit_log_table_name (str): Name of the table to store audit logs. client (bigquery.Client): The BigQuery client instance for executing queries. Methods: getconn() -> bigquery.Client: Establishes a connection to BigQuery and returns a client object. retrieve_df(query) -> pd.DataFrame: Executes a SQL query and returns the results as a pandas DataFrame. make_audit_entry(source_type, user_grouping, model, question, generated_sql, found_in_vector, need_rewrite, failure_step, error_msg, FULL_LOG_TEXT) -> str: Logs an audit entry to BigQuery, recording details of the interaction and the generated SQL query. create_vertex_connection(connection_id) -> None: Creates a Vertex AI connection for remote model usage in BigQuery. create_embedding_model(connection_id, embedding_model) -> None: Creates or replaces an embedding model in BigQuery using a Vertex AI connection. retrieve_matches(mode, user_grouping, qe, similarity_threshold, limit) -> list: Retrieves the most similar table schemas, column schemas, or example queries based on the given mode and parameters. getSimilarMatches(mode, user_grouping, qe, num_matches, similarity_threshold) -> str: Returns a formatted string containing similar matches found for tables, columns, or examples. getExactMatches(query) -> str or None: Checks if the exact question is present in the example SQL set and returns the corresponding SQL query if found. test_sql_plan_execution(generated_sql) -> Tuple[bool, str]: Tests the execution plan of a generated SQL query in BigQuery. Returns a tuple indicating success and a message. return_table_schema_sql(dataset, table_names=None) -> str: Returns a SQL query to retrieve table schema information from a BigQuery dataset. return_column_schema_sql(dataset, table_names=None) -> str: Returns a SQL query to retrieve column schema information from a BigQuery dataset. """ def __init__(self, project_id:str, region:str, opendataqna_dataset:str, audit_log_table_name:str): self.project_id = project_id self.region = region self.opendataqna_dataset = opendataqna_dataset self.audit_log_table_name = audit_log_table_name self.client=self.getconn() def getconn(self): client = bigquery.Client(project=self.project_id) return client def retrieve_df(self,query): return self.client.query_and_wait(query).to_dataframe() def make_audit_entry(self, source_type, user_grouping, model, question, generated_sql, found_in_vector, need_rewrite, failure_step, error_msg, FULL_LOG_TEXT): # global FULL_LOG_TEXT auth_user=get_auth_user() PROJECT_ID = self.project_id table_id= PROJECT_ID+ '.' + self.opendataqna_dataset + '.' + self.audit_log_table_name now = datetime.now() table_exists=False client = self.getconn() df1 = pd.DataFrame(columns=[ 'source_type', 'project_id', 'user', 'user_grouping', 'model_used', 'question', 'generated_sql', 'found_in_vector', 'need_rewrite', 'failure_step', 'error_msg', 'execution_time', 'full_log' ]) new_row = { "source_type":source_type, "project_id":str(PROJECT_ID), "user":str(auth_user), "user_grouping": user_grouping, "model_used": model, "question": question, "generated_sql": generated_sql, "found_in_vector":found_in_vector, "need_rewrite":need_rewrite, "failure_step":failure_step, "error_msg":error_msg, "execution_time": now, "full_log": FULL_LOG_TEXT } df1.loc[len(df1)] = new_row db_schema=[ # Specify the type of columns whose type cannot be auto-detected. For # example the "title" column uses pandas dtype "object", so its # data type is ambiguous. bigquery.SchemaField("source_type", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("project_id", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("user", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("user_grouping", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("model_used", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("question", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("generated_sql", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("found_in_vector", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("need_rewrite", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("failure_step", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("error_msg", bigquery.enums.SqlTypeNames.STRING), bigquery.SchemaField("execution_time", bigquery.enums.SqlTypeNames.TIMESTAMP), bigquery.SchemaField("full_log", bigquery.enums.SqlTypeNames.STRING), ] try: client.get_table(table_id) # Make an API request. # print("Table {} already exists.".format(table_id)) table_exists=True except NotFound: print("Table {} is not found. Will create this log table".format(table_id)) table_exists=False if table_exists is True: # print('Performing streaming insert') errors = client.insert_rows_from_dataframe(table=table_id, dataframe=df1, selected_fields=db_schema) # Make an API request. if errors == [[]]: print("Logged the run") else: print("Encountered errors while inserting rows: {}".format(errors)) else: job_config = bigquery.LoadJobConfig(schema=db_schema,write_disposition="WRITE_TRUNCATE") # pandas_gbq.to_gbq(df1, table_id, project_id=PROJECT_ID) # replace to replace table; append to append to a table client.load_table_from_dataframe(df1,table_id,job_config=job_config) # replace to replace table; append to append to a table # df1.loc[len(df1)] = new_row # pandas_gbq.to_gbq(df1, table_id, project_id=PROJECT_ID, if_exists='append') # replace to replace table; append to append to a table # print('\n Query added to BQ log table \n') return 'Completed the logging step' def create_vertex_connection(self, connection_id : str): client=bq_connection.ConnectionServiceClient() cloud_resource_properties = bq_connection.types.CloudResourceProperties() new_connection=bq_connection.Connection(cloud_resource=cloud_resource_properties) response= client.create_connection(parent=f'projects/{self.project_id}/locations/{self.region}',connection=new_connection,connection_id=connection_id) def create_embedding_model(self,connection_id: str, embedding_model: str): client = self.getconn() client.query_and_wait(f'''CREATE OR REPLACE MODEL `{self.project_id}.{self.opendataqna_dataset}.EMBEDDING_MODEL` REMOTE WITH CONNECTION `{self.project_id}.{self.region}.{connection_id}` OPTIONS (ENDPOINT = '{embedding_model}');''') def retrieve_matches(self, mode, user_grouping, qe, similarity_threshold, limit): """ This function retrieves the most similar table_schema and column_schema. Modes can be either 'table', 'column', or 'example' """ matches = [] if mode == 'table': sql = '''select base.content as tables_content from vector_search( (SELECT * FROM `{}.table_details_embeddings` WHERE user_grouping = '{}'), "embedding", (SELECT {} as qe), top_k=> {},distance_type=>"COSINE") where 1-distance > {} ''' elif mode == 'column': sql='''select base.content as columns_content from vector_search( (SELECT * FROM `{}.tablecolumn_details_embeddings` WHERE user_grouping = '{}'), "embedding", (SELECT {} as qe), top_k=> {}, distance_type=>"COSINE") where 1-distance > {} ''' elif mode == 'example': sql='''select base.example_user_question, base.example_generated_sql from vector_search ( (SELECT * FROM `{}.example_prompt_sql_embeddings` WHERE user_grouping = '{}'), "embedding", (select {} as qe), top_k=> {}, distance_type=>"COSINE") where 1-distance > {} ''' else: ValueError("No valid mode. Must be either table, column, or example") name_txt = '' results=self.client.query_and_wait(sql.format('{}.{}'.format(self.project_id,self.opendataqna_dataset),user_grouping,qe,limit,similarity_threshold)).to_dataframe() # CHECK RESULTS if len(results) == 0: print(f"Did not find any results for {mode}. Adjust the query parameters.") else: print(f"Found {len(results)} similarity matches for {mode}.") if mode == 'table': name_txt = '' for _ , r in results.iterrows(): name_txt=name_txt+r["tables_content"]+"\n" elif mode == 'column': name_txt = '' for _ ,r in results.iterrows(): name_txt=name_txt+r["columns_content"]+"\n" elif mode == 'example': name_txt = '' for _ , r in results.iterrows(): example_user_question=r["example_user_question"] example_sql=r["example_generated_sql"] name_txt = name_txt + "\n Example_question: "+example_user_question+ "; Example_SQL: "+example_sql else: ValueError("No valid mode. Must be either table, column, or example") name_txt = '' matches.append(name_txt) return matches def getSimilarMatches(self, mode, user_grouping, qe, num_matches, similarity_threshold): if mode == 'table': match_result= self.retrieve_matches(mode, user_grouping, qe, similarity_threshold, num_matches) match_result = match_result[0] # print(match_result) elif mode == 'column': match_result= self.retrieve_matches(mode, user_grouping, qe, similarity_threshold, num_matches) match_result = match_result[0] elif mode == 'example': match_result= self.retrieve_matches(mode, user_grouping, qe, similarity_threshold, num_matches) if len(match_result) == 0: match_result = None else: match_result = match_result[0] return match_result def getExactMatches(self, query): """Checks if the exact question is already present in the example SQL set""" check_history_sql=f"""SELECT example_user_question,example_generated_sql FROM `{self.project_id}.{self.opendataqna_dataset}.example_prompt_sql_embeddings` WHERE lower(example_user_question) = lower("{query}") LIMIT 1; """ exact_sql_history = self.client.query_and_wait(check_history_sql).to_dataframe() if exact_sql_history[exact_sql_history.columns[0]].count() != 0: sql_example_txt = '' exact_sql = '' for index, row in exact_sql_history.iterrows(): example_user_question=row["example_user_question"] example_sql=row["example_generated_sql"] exact_sql=example_sql sql_example_txt = sql_example_txt + "\n Example_question: "+example_user_question+ "; Example_SQL: "+example_sql # print("Found a matching question from the history!" + str(sql_example_txt)) final_sql=exact_sql else: print("No exact match found for the user prompt") final_sql = None return final_sql def test_sql_plan_execution(self, generated_sql): try: exec_result_df="" job_config=bigquery.QueryJobConfig(dry_run=True, use_query_cache=False) query_job = self.client.query(generated_sql,job_config=job_config) # print(query_job) exec_result_df=("This query will process {} bytes.".format(query_job.total_bytes_processed)) correct_sql = True print(exec_result_df) return correct_sql, exec_result_df except Exception as e: return False,str(e) def return_table_schema_sql(self, dataset, table_names=None): """ Returns the SQL query to be run on 'Source DB' to get the Table Schema The SQL query below returns a df containing the cols table_schema, table_name, table_description, table_columns (with cols in the table) for the schema specified above, e.g. 'retail' - table_schema: e.g. retail - table_name: name of the table inside the schema, e.g. products - table_description: text descriptor, can be empty - table_columns: aggregate of the col names inside the table """ user_dataset = self.project_id + '.' + dataset table_filter_clause = "" if table_names: # Extract individual table names from the input string #table_names = [name.strip() for name in table_names[1:-1].split(",")] # Handle the string as a list formatted_table_names = [f"'{name}'" for name in table_names] table_filter_clause = f"""AND TABLE_NAME IN ({', '.join(formatted_table_names)})""" table_schema_sql = f""" (SELECT TABLE_CATALOG as project_id, TABLE_SCHEMA as table_schema , TABLE_NAME as table_name, OPTION_VALUE as table_description, (SELECT STRING_AGG(column_name, ', ') from `{user_dataset}.INFORMATION_SCHEMA.COLUMNS` where TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA) as table_columns FROM `{user_dataset}.INFORMATION_SCHEMA.TABLE_OPTIONS` as t WHERE OPTION_NAME = "description" {table_filter_clause} ORDER BY project_id, table_schema, table_name) UNION ALL (SELECT TABLE_CATALOG as project_id, TABLE_SCHEMA as table_schema , TABLE_NAME as table_name, "NA" as table_description, (SELECT STRING_AGG(column_name, ', ') from `{user_dataset}.INFORMATION_SCHEMA.COLUMNS` where TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA) as table_columns FROM `{user_dataset}.INFORMATION_SCHEMA.TABLES` as t WHERE NOT EXISTS (SELECT 1 FROM `{user_dataset}.INFORMATION_SCHEMA.TABLE_OPTIONS` WHERE OPTION_NAME = "description" AND TABLE_NAME= t.TABLE_NAME and TABLE_SCHEMA=t.TABLE_SCHEMA) {table_filter_clause} ORDER BY project_id, table_schema, table_name) """ return table_schema_sql def return_column_schema_sql(self, dataset, table_names=None): """ Returns the SQL query to be run on 'Source DB' to get the column schema The SQL query below returns a df containing the cols table_schema, table_name, column_name, data_type, column_description, table_description, primary_key, column_constraints for the schema specified above, e.g. 'retail' - table_schema: e.g. retail - table_name: name of the tables inside the schema, e.g. products - column_name: name of each col in each table in the schema, e.g. id_product - data_type: data type of each col - column_description: col descriptor, can be empty - table_description: text descriptor, can be empty - primary_key: whether the col is PK; if yes, the field contains the col_name - column_constraints: e.g. "Primary key for this table" """ user_dataset = self.project_id + '.' + dataset table_filter_clause = "" if table_names: # table_names = [name.strip() for name in table_names[1:-1].split(",")] # Handle the string as a list formatted_table_names = [f"'{name}'" for name in table_names] table_filter_clause = f"""AND C.TABLE_NAME IN ({', '.join(formatted_table_names)})""" column_schema_sql = f""" SELECT C.TABLE_CATALOG as project_id, C.TABLE_SCHEMA as table_schema, C.TABLE_NAME as table_name, C.COLUMN_NAME as column_name, C.DATA_TYPE as data_type, C.DESCRIPTION as column_description, CASE WHEN T.CONSTRAINT_TYPE="PRIMARY KEY" THEN "This Column is a Primary Key for this table" WHEN T.CONSTRAINT_TYPE = "FOREIGN_KEY" THEN "This column is Foreign Key" ELSE NULL END as column_constraints FROM `{user_dataset}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` C LEFT JOIN `{user_dataset}.INFORMATION_SCHEMA.TABLE_CONSTRAINTS` T ON C.TABLE_CATALOG = T.TABLE_CATALOG AND C.TABLE_SCHEMA = T.TABLE_SCHEMA AND C.TABLE_NAME = T.TABLE_NAME AND T.ENFORCED ='YES' LEFT JOIN `{user_dataset}.INFORMATION_SCHEMA.KEY_COLUMN_USAGE` K ON K.CONSTRAINT_NAME=T.CONSTRAINT_NAME AND C.COLUMN_NAME = K.COLUMN_NAME WHERE 1=1 {table_filter_clause} ORDER BY project_id, table_schema, table_name, column_name; """ return column_schema_sql def get_column_samples(self,columns_df): sample_column_list=[] for index, row in columns_df.iterrows(): get_column_sample_sql=f'''SELECT STRING_AGG(CAST(value AS STRING)) as sample_values FROM UNNEST((SELECT APPROX_TOP_COUNT({row["column_name"]},5) as osn FROM `{row["project_id"]}.{row["table_schema"]}.{row["table_name"]}` ))''' column_samples_df=self.retrieve_df(get_column_sample_sql) # display(column_samples_df) sample_column_list.append(column_samples_df['sample_values'].to_string(index=False)) columns_df["sample_values"]=sample_column_list return columns_df