bq-connector/docai_bq_connector/bigquery/StorageManager.py (84 lines of code) (raw):

#
# Copyright 2022 Google LLC
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import logging

from google.cloud import bigquery
from google.cloud.exceptions import NotFound


class StorageManager:
    def __init__(self, project_id: str, dataset_id: str):
        self.project_id = project_id
        self.dataset_id = dataset_id
        self.client: bigquery.client.Client = None
        if project_id is None:
            self.client = bigquery.Client()
            self.project_id = self.client.project
        else:
            self.client = bigquery.Client(project=project_id)
        self.dataset_ref = bigquery.DatasetReference(self.project_id, self.dataset_id)

    def _does_dataset_exist(self, dataset_ref) -> bool:
        try:
            self.client.get_dataset(dataset_ref)
            logging.debug("Dataset %s already exists", dataset_ref)
            return True
        except NotFound:
            logging.debug("Dataset %s is not found", dataset_ref)
            return False

    def does_table_exist(self, name):
        table_ref = bigquery.TableReference(self.dataset_ref, name)
        try:
            self.client.get_table(table_ref)
            logging.debug("Table %s already exists.", table_ref.table_id)
            return True
        except NotFound:
            logging.debug("Table %s is not found.", table_ref.table_id)
            return False

    def write_record(self, table_id: str, record):
        table_ref = bigquery.TableReference(self.dataset_ref, table_id)
        errors = self.client.insert_rows_json(table_ref, [record])
        if errors:
            logging.error("Encountered errors while inserting rows: %s", errors)
        return errors

    def update_record(
        self, table_id: str, record_id_name, record_id_value, cols_to_update
    ):
        # Assumes all columns and key are of type STRING
        query_params = []
        dml_statement = f"UPDATE `{self.project_id}.{self.dataset_id}.{table_id}` SET"
        if len(cols_to_update) == 0:
            # Nothing to do
            return
        idx = 0
        for index, (cur_col, cur_val) in enumerate(cols_to_update.items()):
            dml_statement = f"{dml_statement} {cur_col} = @param_{idx},"
            cur_qp = bigquery.ScalarQueryParameter(f"param_{idx}", "STRING", cur_val)
            query_params.append(cur_qp)
            idx += 1
        # Remove last comma
        dml_statement = (
            f"{dml_statement[:-1]} WHERE {record_id_name} = @param_record_id"
        )
        cur_qp = bigquery.ScalarQueryParameter(
            "param_record_id", "STRING", record_id_value
        )
        query_params.append(cur_qp)
        logging.debug(
            f"About to run query: {dml_statement} with params: {query_params}"
        )
        query_job_config = bigquery.QueryJobConfig(
            use_legacy_sql=False, query_parameters=query_params
        )
        query_job = self.client.query(query=dml_statement, job_config=query_job_config)
        query_job.result()

    def get_records(self, query: str, query_params=[]):
        # Only supports scalar parameters
        bq_q_params = []
        for qp in query_params:
            logging.debug(qp)
            cur_p = bigquery.ScalarQueryParameter(qp["name"], qp["type"], qp["value"])
            bq_q_params.append(cur_p)
        logging.debug(f"About to run query: {query} with params: {bq_q_params}")
        query_job_config = bigquery.QueryJobConfig(
            use_legacy_sql=False, query_parameters=bq_q_params
        )
        query_job = self.client.query(query=query, job_config=query_job_config)
        bq_rows = query_job.result()
        # Convert to array of dict
        records = [dict(row) for row in bq_rows]
        logging.debug(f"Will return {records}")
        return records

    def get_table_schema(self, table_id: str):
        table_ref = bigquery.TableReference(self.dataset_ref, table_id)
        table = self.client.get_table(table_ref)
        return table.schema
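

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module).
# The project, dataset, table, and column names below are hypothetical
# placeholders; running this requires valid credentials and existing
# BigQuery resources.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = StorageManager(project_id="my-project", dataset_id="my_dataset")

    # Stream a single row into a table via the insert_rows_json-based helper.
    insert_errors = manager.write_record(
        "documents", {"doc_id": "123", "status": "NEW"}
    )

    # Update STRING columns, keyed on a STRING id column, via parameterized DML.
    manager.update_record(
        table_id="documents",
        record_id_name="doc_id",
        record_id_value="123",
        cols_to_update={"status": "PROCESSED"},
    )

    # Run a parameterized query; results come back as a list of dicts.
    rows = manager.get_records(
        "SELECT * FROM `my-project.my_dataset.documents` WHERE status = @status",
        query_params=[{"name": "status", "type": "STRING", "value": "PROCESSED"}],
    )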