pathology/shared_libs/spanner_lib/cloud_spanner_client.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wrapper for Cloud Spanner for use with the DPAS database.

Provides utilities for read/write access to Cloud Spanner.
"""
from typing import Any, Dict, List, Optional, TypeVar

from absl import flags
from google.cloud import spanner
from google.cloud.spanner_v1 import database
from google.cloud.spanner_v1 import streamed

from pathology.shared_libs.flags import secret_flag_utils
from pathology.shared_libs.logging_lib import cloud_logging_client

INSTANCE_ID_FLG = flags.DEFINE_string(
    'instance_id',
    secret_flag_utils.get_secret_or_env('INSTANCE_ID', None),
    'Instance Id of instance containing database.',
)
DATABASE_NAME_FLG = flags.DEFINE_string(
    'database_name',
    secret_flag_utils.get_secret_or_env('DATABASE_NAME', None),
    'Name of database.',
)

CloudSpannerClientType = TypeVar(
    'CloudSpannerClientType', bound='CloudSpannerClient'
)


class CloudSpannerClientInstanceExceptionError(Exception):
  pass


class CloudSpannerClient:
  """Wrapper for Cloud Spanner for use with the DPAS database."""

  def __init__(self):
    """Initializes a Cloud Spanner client wrapper."""
    cloud_logging_client.info(
        'Connecting to Spanner instance.',
        {
            'instance_id': INSTANCE_ID_FLG.value,
            'database_name': DATABASE_NAME_FLG.value,
        },
    )
    self._database = (
        spanner.Client()
        .instance(INSTANCE_ID_FLG.value)
        .database(DATABASE_NAME_FLG.value)
    )

  def execute_sql(
      self,
      sql: str,
      params: Optional[Dict[str, Any]] = None,
      param_types: Optional[Dict[str, Any]] = None,
  ) -> streamed.StreamedResultSet:
    """Executes a SQL statement on a snapshot of the database.

    params and param_types should be specified in non-test usage to
    sanitize user data.

    Args:
      sql: SQL statement to execute.
      params: Key-value pairs of names to params {str -> Any}.
      param_types: Key-value pairs of param types corresponding to params
        {str -> spanner.param_types.Type}.

    Returns:
      StreamedResultSet which can be used to consume rows.
    """
    with self._database.snapshot() as snapshot:
      return snapshot.execute_sql(sql, params, param_types)

  def execute_dml(self, dml: List[str]) -> List[int]:
    """Executes DML statements, each in its own read-write transaction.

    Read-write transactions support read-after-write, allowing uncommitted
    data to be read within the same transaction it was written.

    Args:
      dml: DML statements to execute in order.

    Returns:
      List[int] of rows updated per statement.
    """

    def execute_update(transaction, statement):
      return transaction.execute_update(statement)

    row_ct = []
    for st in dml:
      # Each statement runs in its own read-write transaction.
      row_ct.append(self._database.run_in_transaction(execute_update, st))
    return row_ct

  def run_in_transaction(self, func, *args, **kw):
    """Runs a unit of work in a transaction.

    Ensures that reads/writes in a function occur in the same transaction.

    Args:
      func: Function to run.
      *args: Additional arguments to be passed to func.
      **kw: Keyword arguments to be passed to func.

    Returns:
      Return value of func.
    """
    return self._database.run_in_transaction(func, *args, **kw)

  def get_database_snapshot(self) -> database.SnapshotCheckout:
    """Returns a multi-use snapshot of the Spanner database."""
    return self._database.snapshot(multi_use=True)

  def read_data(
      self,
      table: str,
      columns: List[str],
      keys: spanner.KeySet,
      index: Optional[str] = None,
  ) -> streamed.StreamedResultSet:
    """Reads data from a snapshot of the database.

    Args:
      table: Name of table to read from.
      columns: Column names of table.
      keys: KeySet defining keys to read.
      index: Secondary index to use instead of primary key.

    Returns:
      StreamedResultSet which can be used to consume rows.
    """
    with self._database.snapshot() as snapshot:
      return snapshot.read(table, columns, keys, index)

  def insert_data(
      self, table: str, columns: List[str], values: List[List[Any]]
  ) -> None:
    """Inserts data using mutations.

    A mutation is a sequence of inserts, updates, and deletes. All mutations
    in a single batch are applied atomically. Does not allow for reads of
    uncommitted data.

    Args:
      table: Name of table to insert into.
      columns: Column names.
      values: Values to insert.

    Returns:
      None
    """
    with self._database.batch() as batch:
      batch.insert(table, columns, values)

  def update_data(
      self, table: str, columns: List[str], values: List[List[Any]]
  ) -> None:
    """Updates data using mutations.

    A mutation is a sequence of inserts, updates, and deletes. All mutations
    in a single batch are applied atomically. Does not allow for reads of
    uncommitted data.

    Args:
      table: Name of table to update.
      columns: Column names.
      values: Values to update.

    Returns:
      None
    """
    with self._database.batch() as batch:
      batch.update(table, columns, values)

  def insert_or_update_data(
      self, table: str, columns: List[str], values: List[List[Any]]
  ) -> None:
    """Inserts or updates data using mutations.

    A mutation is a sequence of inserts, updates, and deletes. All mutations
    in a single batch are applied atomically. Does not allow for reads of
    uncommitted data.

    Args:
      table: Name of table to insert into or update.
      columns: Column names.
      values: Values to insert or update.

    Returns:
      None
    """
    with self._database.batch() as batch:
      batch.insert_or_update(table, columns, values)

  def delete_data(
      self,
      table: str,
      keys: Optional[List[List[Any]]] = None,
      all_keys: bool = False,
  ) -> None:
    """Deletes data using mutations.

    A mutation is a sequence of inserts, updates, and deletes. All mutations
    in a single batch are applied atomically. Does not allow for reads of
    uncommitted data.

    Args:
      table: Name of table to delete from.
      keys: List of primary key values of rows to delete.
      all_keys: True to delete all rows in the table.

    Returns:
      None
    """
    key_set = (
        spanner.KeySet(all_=True) if all_keys else spanner.KeySet(keys=keys)
    )
    with self._database.batch() as batch:
      batch.delete(table, key_set)
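

# ------------------------------------------------------------------------------
# Illustrative usage (a minimal sketch added for documentation; not part of the
# production module). The table, column, and key names below ('Patients',
# 'PatientId', 'Name', 'p-123', 'p-456') are hypothetical. This assumes the
# 'instance_id' and 'database_name' flags (or the INSTANCE_ID / DATABASE_NAME
# secrets/env vars) point at a real Spanner database and that absl flags have
# been parsed (e.g. via absl.app.run) before CloudSpannerClient() is built.
#
#   from google.cloud.spanner_v1 import param_types
#
#   client = CloudSpannerClient()
#
#   # Parameterized read; params/param_types sanitize user-supplied values.
#   rows = client.execute_sql(
#       'SELECT PatientId, Name FROM Patients WHERE PatientId = @pid',
#       params={'pid': 'p-123'},
#       param_types={'pid': param_types.STRING},
#   )
#   for row in rows:
#     print(row)
#
#   # Mutation-based insert; all mutations in the batch commit atomically.
#   client.insert_data(
#       'Patients', ['PatientId', 'Name'], [['p-456', 'Jane Doe']]
#   )
#
#   # Read-modify-write in a single transaction via run_in_transaction.
#   def _rename(transaction, pid, new_name):
#     transaction.execute_update(
#         'UPDATE Patients SET Name = @name WHERE PatientId = @pid',
#         params={'name': new_name, 'pid': pid},
#         param_types={'name': param_types.STRING, 'pid': param_types.STRING},
#     )
#
#   client.run_in_transaction(_rename, 'p-456', 'Janet Doe')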