notebooks/util/jdbc/engines/oracle_input_manager.py

# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from textwrap import dedent
from typing import List, Optional, Tuple

import sqlalchemy

from util.jdbc.jdbc_input_manager_interface import (
    JDBCInputManagerInterface,
    JDBCInputManagerException,
    SPARK_PARTITION_COLUMN,
    SPARK_NUM_PARTITIONS,
    SPARK_LOWER_BOUND,
    SPARK_UPPER_BOUND,
    PARTITION_COMMENT,
)


class OracleInputManager(JDBCInputManagerInterface):

    # Private methods

    def _build_table_list(
        self,
        schema_filter: Optional[str] = None,
        table_filter: Optional[List[str]] = None,
    ) -> Tuple[str, List[str]]:
        """
        Return a tuple containing the schema and a list of table names based on
        optional schema/table filters.
        If schema_filter is not provided then the connected user is used as the schema.
        """
        with self._alchemy_db.connect() as conn:
            schema = self._normalise_schema_filter(schema_filter, conn)
            not_like_filter = "table_name NOT LIKE 'DR$SUP_TEXT_IDX%'"
            if schema_filter:
                sql = f"SELECT table_name FROM all_tables WHERE owner = :own AND {not_like_filter}"
                rows = conn.execute(sqlalchemy.text(sql), {"own": schema}).fetchall()
            else:
                sql = f"SELECT table_name FROM user_tables WHERE {not_like_filter}"
                rows = conn.execute(sqlalchemy.text(sql)).fetchall()
            tables = [_[0] for _ in rows] if rows else rows
            return schema, self._filter_table_list(tables, table_filter)

    def _define_read_partitioning(
        self,
        table: str,
        row_count_threshold: int,
        sa_connection: "sqlalchemy.engine.base.Connection",
        custom_partition_column: Optional[str],
    ) -> Optional[dict]:
        """Return a dictionary defining how to partition the Spark SQL extraction."""
        # TODO In the future we may want to support checking DBA_SEGMENTS.
        row_count = self._get_table_count_from_stats(table, sa_connection=sa_connection)
        if not row_count or row_count < row_count_threshold:
            # In case this is a new table with no stats, or a table with stale stats, do a full count.
            row_count = self._get_table_count(table, sa_connection=sa_connection)
        if row_count < row_count_threshold:
            # The table does not have enough rows to merit partitioning the Spark SQL read.
            return None
        accepted_data_types = ["NUMBER"]
        if custom_partition_column:
            # The user provided a partition column.
            column = self._normalise_column_name(
                table, custom_partition_column, sa_connection
            )
            if not column:
                return {
                    PARTITION_COMMENT: f"Serial read, column does not exist: {custom_partition_column}"
                }
            partition_options = self._define_native_column_read_partitioning(
                table,
                column,
                accepted_data_types,
                row_count,
                row_count_threshold,
                "user provided column",
                sa_connection,
            )
            if partition_options:
                return partition_options
        # TODO Prioritise partition keys over primary keys in the future.
        # TODO Add support for unique keys alongside PKs.
        pk_cols = self.get_primary_keys().get(table)
        if pk_cols and len(pk_cols) == 1:
            # Partition by primary key singleton.
            column = pk_cols[0]
            partition_options = self._define_native_column_read_partitioning(
                table,
                column,
                accepted_data_types,
                row_count,
                row_count_threshold,
                "primary key column",
                sa_connection,
            )
            if partition_options:
                return partition_options
        return None

    def _enclose_identifier(self, identifier, ch: Optional[str] = None) -> str:
        """Enclose an identifier in the standard way for the SQL engine."""
        ch = ch or '"'
        return f"{ch}{identifier}{ch}"

    def _get_column_data_type(
        self,
        table: str,
        column: str,
        sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None,
    ) -> Optional[str]:
        sql = dedent(
            """
            SELECT data_type
            FROM all_tab_columns
            WHERE owner = :own
            AND table_name = :tab
            AND column_name = :col
            """
        )
        if sa_connection:
            row = sa_connection.execute(
                sqlalchemy.text(sql), {"own": self._schema, "tab": table, "col": column}
            ).fetchone()
        else:
            with self._alchemy_db.connect() as conn:
                row = conn.execute(
                    sqlalchemy.text(sql),
                    {"own": self._schema, "tab": table, "col": column},
                ).fetchone()
        return self._normalise_oracle_data_type(row[0]) if row else row

    def _get_primary_keys(self) -> dict:
        """
        Return a dict of primary key information.
        The dict is keyed on table name and maps to a list of column names.
        """
        pk_dict = {_: None for _ in self._table_list}
        sql = dedent(
            """
            SELECT cols.column_name
            FROM all_constraints cons
            ,    all_cons_columns cols
            WHERE cons.owner = :own
            AND cons.table_name = :tab
            AND cons.constraint_type = 'P'
            AND cons.status = 'ENABLED'
            AND cols.constraint_name = cons.constraint_name
            AND cols.owner = cons.owner
            AND cols.table_name = cons.table_name
            ORDER BY cols.position
            """
        )
        with self._alchemy_db.connect() as conn:
            for table in self._table_list:
                rows = conn.execute(
                    sqlalchemy.text(sql), {"own": self._schema, "tab": table}
                ).fetchall()
                if rows:
                    pk_dict[table] = [_[0] for _ in rows]
        return pk_dict

    def _get_table_count_from_stats(
        self,
        table: str,
        sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None,
    ) -> Optional[int]:
        """Return the table row count from gathered optimizer statistics rather than running count(*)."""
        sql = dedent(
            """
            SELECT num_rows
            FROM all_tables
            WHERE owner = :own
            AND table_name = :tab
            """
        )
        if sa_connection:
            row = sa_connection.execute(
                sqlalchemy.text(sql), {"own": self._schema, "tab": table}
            ).fetchone()
        else:
            with self._alchemy_db.connect() as conn:
                row = conn.execute(
                    sqlalchemy.text(sql), {"own": self._schema, "tab": table}
                ).fetchone()
        return row[0] if row else row

    def _normalise_column_name(
        self,
        table: str,
        column: str,
        sa_connection: "sqlalchemy.engine.base.Connection",
    ) -> Optional[str]:
        sql = dedent(
            """
            SELECT column_name
            FROM all_tab_columns
            WHERE owner = :own
            AND table_name = :tab
            AND UPPER(column_name) = UPPER(:col)"""
        )
        row = sa_connection.execute(
            sqlalchemy.text(sql), {"own": self._schema, "tab": table, "col": column}
        ).fetchone()
        return row[0] if row else row

    def _normalise_oracle_data_type(self, data_type: str) -> str:
        """Oracle TIMESTAMP/INTERVAL types are polluted with scale; this method strips that noise away."""
        if data_type.startswith("TIMESTAMP") or data_type.startswith("INTERVAL DAY"):
            return re.sub(r"\([0-9]\)", r"", data_type)
        else:
            return data_type

    def _normalise_schema_filter(
        self, schema_filter: str, sa_connection: "sqlalchemy.engine.base.Connection"
    ) -> str:
        """Return schema_filter normalised to the correct case, or the connected user if schema_filter is blank."""
        if schema_filter:
            # Assuming there will not be multiple schemas of the same name in different case.
sql = "SELECT username FROM all_users WHERE UPPER(username) = UPPER(:b1) ORDER BY username" row = sa_connection.execute( sqlalchemy.text(sql), {"b1": schema_filter} ).fetchone() if not row: raise JDBCInputManagerException( f"Schema filter does not match any Oracle schemas: {schema_filter}" ) else: sql = "SELECT USER FROM dual" row = sa_connection.execute(sqlalchemy.text(sql)).fetchone() return row[0] if row else row # Public methods