# data_validation/clients.py
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
import copy
import logging
from typing import TYPE_CHECKING
import warnings
import google.oauth2.service_account
from google.cloud import bigquery
from google.api_core import client_options
import ibis
import pandas
from data_validation import client_info, consts, exceptions
from data_validation.secret_manager import SecretManagerBuilder
from third_party.ibis.ibis_cloud_spanner.api import spanner_connect
from third_party.ibis.ibis_impala.api import impala_connect
from third_party.ibis.ibis_mssql.api import mssql_connect
from third_party.ibis.ibis_redshift.api import redshift_connect
if TYPE_CHECKING:
import ibis.expr.schema as sch
import ibis.expr.types as ir
# Remove Ibis' default row limit so validations see full tables/queries.
ibis.options.sql.default_limit = None
# Filter Ibis MySQL error when loading client.table()
warnings.filterwarnings(
    "ignore",
    "`BaseBackend.database` is deprecated; use equivalent methods in the backend",
)
# Backends implemented on top of SQLAlchemy; these accept schema= in
# client.table() calls (see is_sqlalchemy_backend below).
IBIS_ALCHEMY_BACKENDS = [
    "mysql",
    "oracle",
    "postgres",
    "db2",
    "mssql",
    "redshift",
    "snowflake",
]
def _raise_missing_client_error(msg):
def get_client_call(*args, **kwargs):
raise Exception(msg)
return get_client_call
# Optional backends: each import below may fail when its driver package is
# missing. In that case the connect function is replaced with a stub that
# raises an install-hint Exception only when that connection type is used.
# Teradata requires teradatasql and licensing
try:
    from third_party.ibis.ibis_teradata.api import teradata_connect
except Exception:
    msg = "pip install teradatasql (requires Teradata licensing)"
    teradata_connect = _raise_missing_client_error(msg)
# Oracle requires cx_Oracle driver
try:
    from third_party.ibis.ibis_oracle.api import oracle_connect
except Exception:
    oracle_connect = _raise_missing_client_error("pip install cx_Oracle")
# Snowflake requires snowflake-connector-python and snowflake-sqlalchemy
try:
    from third_party.ibis.ibis_snowflake.api import snowflake_connect
except Exception:
    snowflake_connect = _raise_missing_client_error(
        "pip install snowflake-connector-python && pip install snowflake-sqlalchemy"
    )
# DB2 requires ibm_db_sa
try:
    from third_party.ibis.ibis_db2.api import db2_connect
except Exception:
    db2_connect = _raise_missing_client_error("pip install ibm_db_sa")
def get_google_bigquery_client(
    project_id: str, credentials=None, api_endpoint: str = None
):
    """Return a google-cloud-bigquery Client configured for DVT.

    The client carries DVT's HTTP user-agent, defaults query sessions to the
    UTC time zone, and optionally targets a custom API endpoint.
    """
    http_info = client_info.get_http_client_info()
    default_job_config = bigquery.QueryJobConfig(
        connection_properties=[bigquery.ConnectionProperty("time_zone", "UTC")]
    )
    endpoint_options = (
        client_options.ClientOptions(api_endpoint=api_endpoint)
        if api_endpoint
        else None
    )
    return bigquery.Client(
        project=project_id,
        client_info=http_info,
        credentials=credentials,
        default_query_job_config=default_job_config,
        client_options=endpoint_options,
    )
def get_bigquery_client(
    project_id: str, dataset_id: str = "", credentials=None, api_endpoint: str = None
):
    """Return an Ibis BigQuery client backed by a DVT-configured google client."""
    ibis_client = ibis.bigquery.connect(
        project_id=project_id,
        dataset_id=dataset_id,
        credentials=credentials,
    )
    # Override the BigQuery client object to ensure the correct user agent is
    # included and any api_endpoint is used.
    ibis_client.client = get_google_bigquery_client(
        project_id, credentials=credentials, api_endpoint=api_endpoint
    )
    return ibis_client
def get_pandas_client(table_name, file_path, file_type):
    """Return pandas client and env with file loaded into DataFrame

    table_name (str): Table name to use as reference for file data
    file_path (str): The local, s3, or GCS file path to the data
    file_type (str): The file type of the file (csv, json, orc or parquet)
    """
    # Dispatch table instead of an if/elif chain over supported formats.
    readers = {
        "csv": pandas.read_csv,
        "json": pandas.read_json,
        "orc": pandas.read_orc,
        "parquet": pandas.read_parquet,
    }
    if file_type not in readers:
        raise ValueError(f"Unknown Pandas File Type: {file_type}")
    dataframe = readers[file_type](file_path)
    return ibis.pandas.connect({table_name: dataframe})
def is_sqlalchemy_backend(client):
    """Return True if client is one of the SQLAlchemy-based Ibis backends.

    client (IbisClient): Client to check; clients without a usable `name`
        attribute are reported as not SQLAlchemy-based.
    """
    try:
        # `in` already yields a bool; the previous bool() wrapper was redundant.
        return client.name in IBIS_ALCHEMY_BACKENDS
    except Exception:
        return False
def is_oracle_client(client):
    """Return True when the supplied client is an Oracle backend."""
    try:
        backend_name = client.name
    except TypeError:
        # When no Oracle backend has been installed OracleBackend is not a class
        return False
    return backend_name == "oracle"
def get_ibis_table(client, schema_name, table_name, database_name=None):
    """Return Ibis Table for Supplied Client.

    client (IbisClient): Client to use for table
    schema_name (str): Schema name of table object
    table_name (str): Table name of table object
    database_name (str): Database name (generally default is used)
    """
    backend = client.name
    if backend == "pandas":
        return client.table(table_name, schema=schema_name)
    # These SQLAlchemy-style backends take both database and schema keywords.
    if backend in ("oracle", "postgres", "db2", "mssql", "redshift"):
        return client.table(table_name, database=database_name, schema=schema_name)
    # Remaining backends treat the schema as the "database" argument.
    return client.table(table_name, database=schema_name)
def get_ibis_query(client, query) -> "ir.Table":
    """Return Ibis Table from query expression for Supplied Client."""
    query_expr = client.sql(query)
    # Normalise all columns in the query to lower case.
    # https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/992
    rename_map = {name: name.lower() for name in query_expr.columns}
    return query_expr.relabel(rename_map)
def get_ibis_table_schema(client, schema_name: str, table_name: str) -> "sch.Schema":
    """Return Ibis Table Schema for Supplied Client.

    client (IbisClient): Client to use for table
    schema_name (str): Schema name of table object
    table_name (str): Table name of table object
    """
    if not is_sqlalchemy_backend(client):
        return client.get_schema(table_name, schema_name)
    # SQLAlchemy backends expose the schema via the table expression.
    return client.table(table_name, schema=schema_name).schema()
def get_ibis_query_schema(client, query_str) -> "sch.Schema":
    """Return the Ibis schema for the result of a custom query."""
    if not is_sqlalchemy_backend(client):
        # NJ: I'm not happy about calling a private method but don't see how I can avoid it.
        # Ibis does not expose a public method like it does for get_schema().
        return client._get_schema_using_query(query_str)
    return get_ibis_query(client, query_str).schema()
def list_schemas(client):
    """Return a list of schemas in the DB, or [None] when unsupported."""
    if not hasattr(client, "list_databases"):
        return [None]
    try:
        return client.list_databases()
    except NotImplementedError:
        # Some backends declare the method but do not implement it.
        return [None]
def list_tables(client, schema_name, tables_only=True):
    """Return a list of tables in the DB schema."""
    # dvt_list_tables excludes views; the pandas backend has no such method.
    use_dvt_listing = tables_only and client.name != "pandas"
    list_fn = client.dvt_list_tables if use_dvt_listing else client.list_tables
    # These backends do not accept a database/schema argument.
    if client.name in ("db2", "mssql", "redshift", "snowflake", "pandas"):
        return list_fn()
    return list_fn(database=schema_name)
def get_all_tables(client, allowed_schemas=None, tables_only=True):
    """Return a list of tuples with database and table names.

    client (IbisClient): Client to use for tables
    allowed_schemas (List[str]): List of schemas to pull.
    """
    results = []
    for schema_name in list_schemas(client):
        if allowed_schemas and schema_name not in allowed_schemas:
            continue
        try:
            table_names = list_tables(client, schema_name, tables_only=tables_only)
        except Exception as e:
            # Best effort: a failure listing one schema should not stop the scan.
            logging.warning(f"List Tables Error: {schema_name} -> {e}")
            continue
        results.extend((schema_name, table_name) for table_name in table_names)
    return results
def get_data_client(connection_config):
    """Return DataClient client from given configuration.

    connection_config (dict): Connection details including consts.SOURCE_TYPE
        and optional secret-manager settings. The dict is deep-copied so the
        caller's configuration is never mutated.

    Raises:
        Exception: If the source type is not present in CLIENT_LOOKUP.
        exceptions.DataClientConnectionFailure: If the connect call fails;
            the original error is chained as the cause.
    """
    connection_config = copy.deepcopy(connection_config)
    source_type = connection_config.pop(consts.SOURCE_TYPE)
    secret_manager_type = connection_config.pop(consts.SECRET_MANAGER_TYPE, None)
    secret_manager_project_id = connection_config.pop(
        consts.SECRET_MANAGER_PROJECT_ID, None
    )

    # Resolve any secret-manager references into literal config values.
    if secret_manager_type is not None:
        sm = SecretManagerBuilder().build(secret_manager_type.lower())
        decrypted_connection_config = {
            key: sm.maybe_secret(secret_manager_project_id, value)
            for key, value in connection_config.items()
        }
    else:
        decrypted_connection_config = connection_config

    # The ibis_bigquery.connect expects a credentials object, not a string.
    if consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH in decrypted_connection_config:
        key_path = decrypted_connection_config.pop(
            consts.GOOGLE_SERVICE_ACCOUNT_KEY_PATH
        )
        if key_path:
            decrypted_connection_config[
                "credentials"
            ] = google.oauth2.service_account.Credentials.from_service_account_file(
                key_path
            )

    if source_type not in CLIENT_LOOKUP:
        msg = 'ConfigurationError: Source type "{source_type}" is not supported'.format(
            source_type=source_type
        )
        raise Exception(msg)

    try:
        data_client = CLIENT_LOOKUP[source_type](**decrypted_connection_config)
        data_client._source_type = source_type
    except Exception as e:
        msg = 'Connection Type "{source_type}" could not connect: {error}'.format(
            source_type=source_type, error=str(e)
        )
        # Chain the underlying error so the root cause shows in tracebacks.
        raise exceptions.DataClientConnectionFailure(msg) from e
    return data_client
@contextmanager
def get_data_client_ctx(*args, **kwargs):
    """Provide get_data_client() via a context manager."""
    client = None
    try:
        client = get_data_client(*args, **kwargs)
        yield client
    finally:
        # TODO When we upgrade Ibis beyond 5.x this try/except may become redundant.
        # https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/1376
        close_fn = getattr(client, "close", None)
        if close_fn is not None:
            try:
                close_fn()
            except Exception as exc:
                # No need to reraise, we can silently fail if exiting throws up an issue.
                logging.warning("Exception closing connection: %s", str(exc))
def get_max_column_length(client):
    """Return the max column length supported by client.

    client (IbisClient): Client to use for tables
    """
    if is_oracle_client(client):
        # We can't reliably know which Version class client.version is stored in
        # because it is out of our control. Therefore using string identification
        # of Oracle <= 12.1 to avoid exceptions of this nature:
        # TypeError: '<' not supported between instances of 'Version' and 'Version'
        version_str = str(client.version)
        if version_str.startswith(("10", "11")) or version_str[:4] == "12.1":
            return 30
    return 128
def get_max_in_list_size(client, in_list_over_expressions=False):
    """Return the engine-specific cap on SQL IN-list size, or None for no cap."""
    if client.name == "snowflake":
        # This is a workaround for Snowflake limitation:
        # SQL compilation error: In-list contains more than 50 non-constant values
        # in_list_over_expressions means the list holds expressions (e.g. casts)
        # rather than simple literals.
        return 50 if in_list_over_expressions else 16000
    if is_oracle_client(client):
        # This is a workaround for Oracle limitation:
        # ORA-01795: maximum number of expressions in a list is 1000
        return 1000
    return None
# Maps each consts.SOURCE_TYPE_* value to the callable that builds a client
# for that engine; consumed by get_data_client(). Optional backends whose
# driver import failed map to a stub that raises an install hint when called.
CLIENT_LOOKUP = {
    consts.SOURCE_TYPE_BIGQUERY: get_bigquery_client,
    consts.SOURCE_TYPE_IMPALA: impala_connect,
    consts.SOURCE_TYPE_MYSQL: ibis.mysql.connect,
    consts.SOURCE_TYPE_ORACLE: oracle_connect,
    consts.SOURCE_TYPE_FILESYSTEM: get_pandas_client,
    consts.SOURCE_TYPE_POSTGRES: ibis.postgres.connect,
    consts.SOURCE_TYPE_REDSHIFT: redshift_connect,
    consts.SOURCE_TYPE_TERADATA: teradata_connect,
    consts.SOURCE_TYPE_MSSQL: mssql_connect,
    consts.SOURCE_TYPE_SNOWFLAKE: snowflake_connect,
    consts.SOURCE_TYPE_SPANNER: spanner_connect,
    consts.SOURCE_TYPE_DB2: db2_connect,
}