# notebooks/util/jdbc/engines/mysql_input_manager.py
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from textwrap import dedent
from typing import List, Optional, Tuple

import sqlalchemy

from util.jdbc.jdbc_input_manager_interface import (
JDBCInputManagerInterface,
JDBCInputManagerException,
SPARK_PARTITION_COLUMN,
SPARK_NUM_PARTITIONS,
SPARK_LOWER_BOUND,
SPARK_UPPER_BOUND,
PARTITION_COMMENT,
)


class MySQLInputManager(JDBCInputManagerInterface):
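    """MySQL implementation of the JDBC input manager.

    Builds the candidate table list from the connected database and decides
    whether, and on which column, a Spark JDBC read of each table should be
    partitioned.
    """
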
# Private methods
    def _build_table_list(
        self,
        schema_filter: Optional[str] = None,
        table_filter: Optional[List[str]] = None,
    ) -> Tuple[str, List[str]]:
        """
        Return a tuple containing the schema and a list of table names based on an optional table filter.
        schema_filter is only validated, not used for lookups, because in MySQL the schema is derived from the connected database.
        """
        with self._alchemy_db.connect() as conn:
            schema = self._normalise_schema_filter(schema_filter, conn)
            sql = "SHOW TABLES"
            rows = conn.execute(sqlalchemy.text(sql)).fetchall()
            tables = [row[0] for row in rows]
            return schema, self._filter_table_list(tables, table_filter)
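
    # For example (hypothetical data): connected to a database named "sales"
    # and called with table_filter=["orders"], _build_table_list is expected
    # to return something like:
    #
    #   ("sales", ["orders"])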

    def _define_read_partitioning(
        self,
        table: str,
        row_count_threshold: int,
        sa_connection: "sqlalchemy.engine.base.Connection",
        custom_partition_column: Optional[str],
    ) -> Optional[dict]:
        """Return a dictionary defining how to partition the Spark SQL extraction, or None for a serial read."""
        row_count = self._get_table_count_from_stats(table, sa_connection=sa_connection)
        if not row_count:
            # In case this is a new table with no stats, do a full count.
            row_count = self._get_table_count(table, sa_connection=sa_connection)
        if row_count < int(row_count_threshold):
            # The table does not have enough rows to merit partitioning the Spark SQL read.
            return None
        accepted_data_types = ["int", "bigint", "mediumint"]
        if custom_partition_column:
            # The user provided a partition column.
            column = self._normalise_column_name(
                table, custom_partition_column, sa_connection
            )
            if not column:
                return {
                    PARTITION_COMMENT: f"Serial read, column does not exist: {custom_partition_column}"
                }
            partition_options = self._define_native_column_read_partitioning(
                table,
                column,
                accepted_data_types,
                row_count,
                row_count_threshold,
                "user provided column",
                sa_connection,
            )
            if partition_options:
                return partition_options
        pk_cols = self.get_primary_keys().get(table)
        if pk_cols and len(pk_cols) == 1:
            # Partition by singleton primary key.
            column = pk_cols[0]
            partition_options = self._define_native_column_read_partitioning(
                table,
                column,
                accepted_data_types,
                row_count,
                row_count_threshold,
                "primary key column",
                sa_connection,
            )
            if partition_options:
                return partition_options
        return None
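
    # A hedged sketch of how the options returned above are expected to feed
    # a Spark JDBC read. The SPARK_* constants imported at the top of this
    # file are assumed to be the dict keys; spark, url, and props are
    # placeholders, and "orders"/100000 are hypothetical arguments:
    #
    #   opts = mgr._define_read_partitioning("orders", 100000, conn, None)
    #   if opts and SPARK_PARTITION_COLUMN in opts:
    #       df = spark.read.jdbc(
    #           url,
    #           "orders",
    #           column=opts[SPARK_PARTITION_COLUMN],
    #           lowerBound=opts[SPARK_LOWER_BOUND],
    #           upperBound=opts[SPARK_UPPER_BOUND],
    #           numPartitions=opts[SPARK_NUM_PARTITIONS],
    #           properties=props,
    #       )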

    def _enclose_identifier(self, identifier, ch: Optional[str] = None) -> str:
        """Enclose an identifier in the standard way for the SQL engine (backticks for MySQL)."""
        ch = ch or "`"
        return f"{ch}{identifier}{ch}"

    def _get_column_data_type(
        self,
        table: str,
        column: str,
        sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None,
    ) -> Optional[str]:
        """Return the information_schema data type for a column, or None if it does not exist."""
        # SQLAlchemy bind parameters work with MySQL, so use them rather than
        # formatting values into the SQL text.
        sql = sqlalchemy.text(
            dedent(
                """
                SELECT data_type
                FROM information_schema.columns
                WHERE table_schema = :table_schema
                AND table_name = :table_name
                AND column_name = :column_name
                """
            )
        )
        params = {
            "table_schema": self._schema,
            "table_name": table,
            "column_name": column,
        }
        if sa_connection:
            row = sa_connection.execute(sql, params).fetchone()
        else:
            with self._alchemy_db.connect() as conn:
                row = conn.execute(sql, params).fetchone()
        return row[0] if row else row

    def _get_primary_keys(self) -> dict:
        """
        Return a dict of primary key information.
        The dict is keyed on table name and maps to a list of column names.
        """
        pk_dict = {table: None for table in self._table_list}
        with self._alchemy_db.connect() as conn:
            for table in self._table_list:
                # Identifiers cannot be bound as parameters, so enclose the
                # table name instead.
                sql = "SHOW KEYS FROM {} WHERE Key_name = 'PRIMARY'".format(
                    self._enclose_identifier(table)
                )
                rows = conn.execute(sqlalchemy.text(sql)).fetchall()
                if rows:
                    # Column 4 of SHOW KEYS output is Column_name.
                    pk_dict[table] = [row[4] for row in rows]
        return pk_dict
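
    # Shape of the result (hypothetical tables): composite keys keep all of
    # their columns, and tables without a primary key map to None:
    #
    #   {"orders": ["id"], "order_items": ["order_id", "line_no"], "audit_log": None}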

    def _get_table_count_from_stats(
        self,
        table: str,
        sa_connection: "Optional[sqlalchemy.engine.base.Connection]" = None,
    ) -> Optional[int]:
        """
        Return the table row count from gathered stats rather than running count(*).
        Note that table_rows is only an estimate for InnoDB tables.
        """
        sql = sqlalchemy.text(
            dedent(
                """
                SELECT table_rows
                FROM information_schema.tables
                WHERE table_schema = :table_schema
                AND table_name = :table_name
                """
            )
        )
        params = {"table_schema": self._schema, "table_name": table}
        if sa_connection:
            row = sa_connection.execute(sql, params).fetchone()
        else:
            with self._alchemy_db.connect() as conn:
                row = conn.execute(sql, params).fetchone()
        return row[0] if row else row

    def _normalise_column_name(
        self,
        table: str,
        column: str,
        sa_connection: "sqlalchemy.engine.base.Connection",
    ) -> Optional[str]:
        """Return the column name as stored in the catalogue, or None if it does not exist."""
        sql = sqlalchemy.text(
            dedent(
                """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_schema = :table_schema
                AND table_name = :table_name
                AND UPPER(column_name) = UPPER(:column_name)
                """
            )
        )
        params = {
            "table_schema": self._schema,
            "table_name": table,
            "column_name": column,
        }
        row = sa_connection.execute(sql, params).fetchone()
        return row[0] if row else row

    def _normalise_schema_filter(
        self, schema_filter: str, sa_connection: "sqlalchemy.engine.base.Connection"
    ) -> Optional[str]:
        """
        Return the connected database name as the schema.
        schema_filter cannot select a different schema in MySQL; it is only validated against the connected database.
        """
        sql = "SELECT DATABASE()"
        row = sa_connection.execute(sqlalchemy.text(sql)).fetchone()
        if row and schema_filter and schema_filter.upper() != row[0].upper():
            raise JDBCInputManagerException(
                f"Schema filter does not match connected database: {schema_filter} != {row[0]}"
            )
        return row[0] if row else row
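
# A minimal usage sketch (hedged): construction goes through the inherited
# JDBCInputManagerInterface, whose signature is not shown in this file, so
# the entry point below is an assumption; the URL and names are placeholders.
#
#   engine = sqlalchemy.create_engine("mysql+pymysql://user:pw@host:3306/sales")
#   mgr = MySQLInputManager(engine)  # hypothetical constructor
#   schema, tables = mgr._build_table_list(table_filter=["orders"])
#   with engine.connect() as conn:
#       opts = mgr._define_read_partitioning("orders", 100000, conn, None)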