notebooks/util/jdbc/jdbc_input_manager_interface.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC as AbstractClass, abstractmethod
import math
from typing import List, Optional, Tuple

import pandas as pd
import sqlalchemy

SPARK_PARTITION_COLUMN = "partitionColumn"
SPARK_NUM_PARTITIONS = "numPartitions"
SPARK_LOWER_BOUND = "lowerBound"
SPARK_UPPER_BOUND = "upperBound"
PARTITION_COMMENT = "comment"


class JDBCInputManagerException(Exception):
    pass


class JDBCInputManagerInterface(AbstractClass):
    """Defines common code across each engine and enforces the methods each engine should provide."""

    def __init__(self, alchemy_db: "sqlalchemy.engine.base.Engine"):
        self._alchemy_db = alchemy_db
        self._schema = None
        self._table_list = []
        self._pk_dict = {}

    # Abstract methods

    @abstractmethod
    def _build_table_list(
        self,
        schema_filter: Optional[str] = None,
        table_filter: Optional[List[str]] = None,
    ) -> Tuple[str, List[str]]:
        """Engine-specific code returning a tuple of the schema and a list of
        table names based on optional schema/table filters."""

    @abstractmethod
    def _define_read_partitioning(
        self,
        table: str,
        row_count_threshold: int,
        sa_connection: "sqlalchemy.engine.base.Connection",
        custom_partition_column: Optional[str],
    ) -> Optional[dict]:
        """Return a dictionary defining how to partition the Spark SQL extraction."""

    @abstractmethod
    def _enclose_identifier(self, identifier, ch: Optional[str] = None):
        """Enclose an identifier in the standard way for the SQL engine, or in ch
        if an override enclosure character is supplied."""

    @abstractmethod
    def _get_column_data_type(
        self,
        table: str,
        column: str,
        sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None,
    ) -> str:
        """Return the base data type for a column without any scale/precision/length annotation."""

    @abstractmethod
    def _get_primary_keys(self) -> dict:
        """
        Return a dict of primary key information.

        The dict is keyed on the qualified table name (e.g. 'schema.table_name')
        and maps to a list of primary key column names.
        """
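
    # Illustrative sketch (an assumption, not part of this interface): a concrete
    # engine subclass might implement _get_primary_keys() by querying catalog
    # views, e.g. Oracle's ALL_CONSTRAINTS/ALL_CONS_COLUMNS:
    #
    #   def _get_primary_keys(self) -> dict:
    #       sql = (
    #           "SELECT cols.table_name, cols.column_name "
    #           "FROM all_constraints cons "
    #           "JOIN all_cons_columns cols "
    #           "ON cons.constraint_name = cols.constraint_name "
    #           "WHERE cons.constraint_type = 'P' AND cons.owner = :owner"
    #       )
    #       pk_dict = {}
    #       with self._alchemy_db.connect() as conn:
    #           for table_name, column_name in conn.execute(
    #               sqlalchemy.text(sql), {"owner": self._schema}
    #           ):
    #               key = self.qualified_name(self._schema, table_name)
    #               pk_dict.setdefault(key, []).append(column_name)
    #       return pk_dict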
""" @abstractmethod def _get_table_count_from_stats( self, table: str, sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None, ) -> Optional[int]: """Return table count from stats gathering rather than running count(*).""" @abstractmethod def _normalise_schema_filter( self, schema_filter: str, sa_connection: "sqlalchemy.engine.base.Connection" ) -> str: """Return schema_filter normalised to the correct case.""" # Private methods def _define_native_column_read_partitioning( self, table: str, column: str, accepted_data_types: List[str], row_count: int, row_count_threshold: int, column_description: str, sa_connection: "sqlalchemy.engine.base.Connection", ): column_datatype = self._get_column_data_type(table, column) if column_datatype in accepted_data_types: lowerbound = self._get_table_min(table, column, sa_connection=sa_connection) upperbound = self._get_table_max(table, column, sa_connection=sa_connection) if lowerbound and upperbound: # TODO Really we should define num_partitions as ceil(table row count / threshold) # and not as in _read_partitioning_num_partitions() but leaving logic # as-is for now, we can revisit in the future. num_partitions = self._read_partitioning_num_partitions( row_count, row_count_threshold ) return { SPARK_PARTITION_COLUMN: column, SPARK_NUM_PARTITIONS: num_partitions, SPARK_LOWER_BOUND: lowerbound, SPARK_UPPER_BOUND: upperbound, PARTITION_COMMENT: f"Partitioning by {column_datatype} {column_description}", } return None def _filter_table_list(self, table_list: List[str], table_filter: List[str]): """Returns table_list filtered for entries (case-insensitive) in table_filter.""" def table_name(s): """Cater for passing of row returned from SQL which will have the table_name in a list/tuple.""" return s[0] if isinstance(s, (list, tuple)) else s if table_filter: table_filter_upper = [_.upper() for _ in table_filter or []] return [ table_name(_) for _ in table_list if table_name(_).upper() in table_filter_upper ] else: return [table_name(_) for _ in table_list] def _get_count_sql(self, table: str) -> str: # This SQL should be simple enough to work on all engines but may need refactoring in the future. return "SELECT COUNT(*) FROM {}".format( self.qualified_name(self._schema, table, enclosed=True) ) def _get_max_sql(self, table: str, column: str) -> str: # This SQL should be simple enough to work on all engines but may need refactoring in the future. return "SELECT MAX({0}) FROM {1} WHERE {0} IS NOT NULL".format( self._enclose_identifier(column), self.qualified_name(self._schema, table, enclosed=True), ) def _get_min_sql(self, table: str, column: str) -> str: # This SQL should be simple enough to work on all engines but may need refactoring in the future. 
return "SELECT MIN({0}) FROM {1} WHERE {0} IS NOT NULL".format( self._enclose_identifier(column), self.qualified_name(self._schema, table, enclosed=True), ) def _get_table_count( self, table: str, sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None, ) -> Optional[int]: """Return row count for a table.""" sql = self._get_count_sql(table) if sa_connection: row = sa_connection.execute(sqlalchemy.text(sql)).fetchone() else: with self._alchemy_db.connect() as conn: row = conn.execute(sqlalchemy.text(sql)).fetchone() return row[0] if row else row def _get_table_min( self, table: str, column: str, sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None, ) -> Optional[int]: """Return min(column) for a table.""" sql = self._get_min_sql(table, column) if sa_connection: row = sa_connection.execute(sqlalchemy.text(sql)).fetchone() else: with self._alchemy_db.connect() as conn: row = conn.execute(sqlalchemy.text(sql)).fetchone() return row[0] if row else row def _get_table_max( self, table: str, column: str, sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None, ) -> Optional[int]: """Return max(column) for a table.""" sql = self._get_max_sql(table, column) if sa_connection: row = sa_connection.execute(sqlalchemy.text(sql)).fetchone() else: with self._alchemy_db.connect() as conn: row = conn.execute(sqlalchemy.text(sql)).fetchone() return row[0] if row else row def _read_partitioning_num_partitions(self, row_count: int, stride: int) -> int: """Return appropriate Spark SQL numPartition value for input range/row count.""" assert row_count >= 0 assert stride >= 0 if not row_count or not stride: return 1 return math.ceil(row_count / stride) # Public methods def build_table_list( self, schema_filter: Optional[list] = None, table_filter: Optional[List[str]] = None, ) -> List[str]: """ Return a list of (schema, table_name) tuples based on an optional schema filter. If schema_filter is not provided then the connected user is used for the schema. """ self._schema, self._table_list = self._build_table_list( schema_filter=schema_filter, table_filter=table_filter ) return self._table_list def define_read_partitioning( self, row_count_threshold: int, custom_partition_columns: Optional[dict] = None, ) -> dict: """ Return a dictionary defining how to partition the Spark SQL extraction. custom_partition_columns is an optional dict allowing the user to provide any column name as a read partition column. """ read_partition_info = {} # Case insensitive match for custom_partition_column in case user was imprecise. custom_partition_columns = { k.upper(): v for k, v in custom_partition_columns.items() } with self._alchemy_db.connect() as conn: for table in self._table_list: partition_options = self._define_read_partitioning( table, row_count_threshold, conn, custom_partition_column=(custom_partition_columns or {}).get( table.upper() ), ) if partition_options: read_partition_info[table] = partition_options return read_partition_info def get_schema(self) -> str: return self._schema def get_table_list(self) -> List[tuple]: return self._table_list def get_table_list_with_counts(self) -> List[int]: """Return a list of table counts in the same order as the list of tables.""" counts = [] with self._alchemy_db.connect() as conn: for table in self.get_table_list(): counts.append(self._get_table_count(table, sa_connection=conn)) return counts def get_primary_keys(self) -> dict: """ Return a dict of primary key information. The dict is keyed on the qualified table name (e.g. 

    def get_primary_keys(self) -> dict:
        """
        Return a dict of primary key information.

        The dict is keyed on the qualified table name (e.g. 'schema.table_name')
        and maps to a list of primary key column names.
        """
        if not self._pk_dict:
            self._pk_dict = self._get_primary_keys()
        return self._pk_dict

    def normalise_schema(self, schema_filter: str) -> str:
        with self._alchemy_db.connect() as conn:
            self._schema = self._normalise_schema_filter(schema_filter, conn)
        return self._schema

    def qualified_name(self, schema: str, table: str, enclosed=False) -> str:
        if enclosed:
            return (
                self._enclose_identifier(schema)
                + "."
                + self._enclose_identifier(table)
            )
        else:
            return schema + "." + table

    def read_partitioning_df(self, read_partition_info: dict) -> pd.DataFrame:
        """Return a pandas DataFrame to allow tidy display of read partitioning information."""

        def get_read_partition_info(table, info_key):
            return read_partition_info.get(table, {}).get(info_key)

        report_dict = {
            "table": self.get_table_list(),
            "partition_column": [
                get_read_partition_info(_, SPARK_PARTITION_COLUMN)
                for _ in self.get_table_list()
            ],
            "num_partitions": [
                get_read_partition_info(_, SPARK_NUM_PARTITIONS)
                for _ in self.get_table_list()
            ],
            "lower_bound": [
                get_read_partition_info(_, SPARK_LOWER_BOUND)
                for _ in self.get_table_list()
            ],
            "upper_bound": [
                get_read_partition_info(_, SPARK_UPPER_BOUND)
                for _ in self.get_table_list()
            ],
            "comment": [
                get_read_partition_info(_, PARTITION_COMMENT) or "Serial read"
                for _ in self.get_table_list()
            ],
        }
        return pd.DataFrame(report_dict)

    def set_table_list(self, table_list: list) -> None:
        self._table_list = table_list
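

# ---------------------------------------------------------------------------
# Illustrative usage sketch. `MySQLInputManager` is a hypothetical concrete
# subclass implementing the abstract methods above, and the connection URL is
# a placeholder; neither is provided by this module:
#
#   import sqlalchemy
#
#   engine = sqlalchemy.create_engine("mysql+pymysql://user:pass@host/db")
#   mgr = MySQLInputManager(engine)
#   mgr.build_table_list(schema_filter="my_schema")
#   partition_info = mgr.define_read_partitioning(row_count_threshold=1_000_000)
#   print(mgr.read_partitioning_df(partition_info))
# ---------------------------------------------------------------------------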