superset/db_engine_specs/databricks.py (430 lines of code) (raw):

# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import annotations from datetime import datetime from typing import Any, TYPE_CHECKING, TypedDict, Union from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.validate import Range from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from superset.constants import TimeGrain from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.db_engine_specs.hive import HiveEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.utils import json from superset.utils.core import get_user_agent, QuerySource from superset.utils.network import is_hostname_valid, is_port_open if TYPE_CHECKING: from superset.models.core import Database class DatabricksBaseSchema(Schema): """ Fields that are required for both Databricks drivers that uses a dynamic form. """ access_token = fields.Str(required=True) host = fields.Str(required=True) port = fields.Integer( required=True, metadata={"description": __("Database port")}, validate=Range(min=0, max=2**16, max_inclusive=False), ) encryption = fields.Boolean( required=False, metadata={"description": __("Use an encrypted connection to the database")}, ) class DatabricksBaseParametersType(TypedDict): """ The parameters are all the keys that do not exist on the Database model. These are used to build the sqlalchemy uri. """ access_token: str host: str port: int encryption: bool class DatabricksNativeSchema(DatabricksBaseSchema): """ Additional fields required only for the DatabricksNativeEngineSpec. """ database = fields.Str(required=True) class DatabricksNativePropertiesSchema(DatabricksNativeSchema): """ Properties required only for the DatabricksNativeEngineSpec. """ http_path = fields.Str(required=True) class DatabricksNativeParametersType(DatabricksBaseParametersType): """ Additional parameters required only for the DatabricksNativeEngineSpec. """ database: str class DatabricksNativePropertiesType(TypedDict): """ All properties that need to be available to the DatabricksNativeEngineSpec in order tocreate a connection if the dynamic form is used. """ parameters: DatabricksNativeParametersType extra: str class DatabricksPythonConnectorSchema(DatabricksBaseSchema): """ Additional fields required only for the DatabricksPythonConnectorEngineSpec. """ http_path_field = fields.Str(required=True) default_catalog = fields.Str(required=True) default_schema = fields.Str(required=True) class DatabricksPythonConnectorParametersType(DatabricksBaseParametersType): """ Additional parameters required only for the DatabricksPythonConnectorEngineSpec. """ http_path_field: str default_catalog: str default_schema: str class DatabricksPythonConnectorPropertiesType(TypedDict): """ All properties that need to be available to the DatabricksPythonConnectorEngineSpec in order to create a connection if the dynamic form is used. """ parameters: DatabricksPythonConnectorParametersType extra: str time_grain_expressions: dict[str | None, str] = { None: "{col}", TimeGrain.SECOND: "date_trunc('second', {col})", TimeGrain.MINUTE: "date_trunc('minute', {col})", TimeGrain.HOUR: "date_trunc('hour', {col})", TimeGrain.DAY: "date_trunc('day', {col})", TimeGrain.WEEK: "date_trunc('week', {col})", TimeGrain.MONTH: "date_trunc('month', {col})", TimeGrain.QUARTER: "date_trunc('quarter', {col})", TimeGrain.YEAR: "date_trunc('year', {col})", TimeGrain.WEEK_ENDING_SATURDAY: ( "date_trunc('week', {col} + interval '1 day') + interval '5 days'" ), TimeGrain.WEEK_STARTING_SUNDAY: ( "date_trunc('week', {col} + interval '1 day') - interval '1 day'" ), } class DatabricksHiveEngineSpec(HiveEngineSpec): engine_name = "Databricks Interactive Cluster" engine = "databricks" drivers = {"pyhive": "Hive driver for Interactive Cluster"} default_driver = "pyhive" _show_functions_column = "function" _time_grain_expressions = time_grain_expressions class DatabricksBaseEngineSpec(BaseEngineSpec): _time_grain_expressions = time_grain_expressions @classmethod def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None ) -> str | None: return HiveEngineSpec.convert_dttm(target_type, dttm, db_extra=db_extra) @classmethod def epoch_to_dttm(cls) -> str: return HiveEngineSpec.epoch_to_dttm() class DatabricksODBCEngineSpec(DatabricksBaseEngineSpec): engine_name = "Databricks SQL Endpoint" engine = "databricks" drivers = {"pyodbc": "ODBC driver for SQL endpoint"} default_driver = "pyodbc" class DatabricksDynamicBaseEngineSpec(BasicParametersMixin, DatabricksBaseEngineSpec): default_driver = "" encryption_parameters = {"ssl": "1"} required_parameters = {"access_token", "host", "port"} context_key_mapping = { "access_token": "password", "host": "hostname", "port": "port", } @staticmethod def get_extra_params( database: Database, source: QuerySource | None = None ) -> dict[str, Any]: """ Add a user agent to be used in the requests. Trim whitespace from connect_args to avoid databricks driver errors """ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database, source) engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) user_agent = get_user_agent(database, source) connect_args.setdefault("http_headers", [("User-Agent", user_agent)]) connect_args.setdefault("_user_agent_entry", user_agent) # trim whitespace from http_path to avoid databricks errors on connecting if http_path := connect_args.get("http_path"): connect_args["http_path"] = http_path.strip() return extra @classmethod def get_table_names( cls, database: Database, inspector: Inspector, schema: str | None, ) -> set[str]: return super().get_table_names( database, inspector, schema ) - cls.get_view_names(database, inspector, schema) @classmethod def extract_errors( cls, ex: Exception, context: dict[str, Any] | None = None ) -> list[SupersetError]: raw_message = cls._extract_error_message(ex) context = context or {} # access_token isn't currently parseable from the # databricks error response, but adding it in here # for reference if their error message changes for key, value in cls.context_key_mapping.items(): context[key] = context.get(value) for regex, (message, error_type, extra) in cls.custom_errors.items(): match = regex.search(raw_message) if match: params = {**context, **match.groupdict()} extra["engine_name"] = cls.engine_name return [ SupersetError( error_type=error_type, message=message % params, level=ErrorLevel.ERROR, extra=extra, ) ] return [ SupersetError( error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, message=cls._extract_error_message(ex), level=ErrorLevel.ERROR, extra={"engine_name": cls.engine_name}, ) ] @classmethod def validate_parameters( # type: ignore cls, properties: Union[ DatabricksNativePropertiesType, DatabricksPythonConnectorPropertiesType, ], ) -> list[SupersetError]: errors: list[SupersetError] = [] if extra := json.loads(properties.get("extra")): # type: ignore engine_params = extra.get("engine_params", {}) connect_args = engine_params.get("connect_args", {}) parameters = { **properties, **properties.get("parameters", {}), } if connect_args.get("http_path"): parameters["http_path"] = connect_args.get("http_path") present = {key for key in parameters if parameters.get(key, ())} if missing := sorted(cls.required_parameters - present): errors.append( SupersetError( message=f"One or more parameters are missing: {', '.join(missing)}", error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, level=ErrorLevel.WARNING, extra={"missing": missing}, ), ) host = parameters.get("host", None) if not host: return errors if not is_hostname_valid(host): # type: ignore errors.append( SupersetError( message="The hostname provided can't be resolved.", error_type=SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR, level=ErrorLevel.ERROR, extra={"invalid": ["host"]}, ), ) return errors port = parameters.get("port", None) if not port: return errors try: port = int(port) # type: ignore except (ValueError, TypeError): errors.append( SupersetError( message="Port must be a valid integer.", error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR, level=ErrorLevel.ERROR, extra={"invalid": ["port"]}, ), ) if not (isinstance(port, int) and 0 <= port < 2**16): errors.append( SupersetError( message=( "The port must be an integer between 0 and 65535 (inclusive)." ), error_type=SupersetErrorType.CONNECTION_INVALID_PORT_ERROR, level=ErrorLevel.ERROR, extra={"invalid": ["port"]}, ), ) elif not is_port_open(host, port): # type: ignore errors.append( SupersetError( message="The port is closed.", error_type=SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR, level=ErrorLevel.ERROR, extra={"invalid": ["port"]}, ), ) return errors class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): engine = "databricks" engine_name = "Databricks (legacy)" drivers = {"connector": "Native all-purpose driver"} default_driver = "connector" parameters_schema = DatabricksNativeSchema() properties_schema = DatabricksNativePropertiesSchema() sqlalchemy_uri_placeholder = ( "databricks+connector://token:{access_token}@{host}:{port}/{database_name}" ) context_key_mapping = { **DatabricksDynamicBaseEngineSpec.context_key_mapping, "database": "database", "username": "username", } required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | { "database", "extra", } supports_dynamic_schema = True supports_catalog = True supports_dynamic_catalog = True supports_cross_catalog_queries = True @classmethod def build_sqlalchemy_uri( # type: ignore cls, parameters: DatabricksNativeParametersType, *_ ) -> str: query = {} if parameters.get("encryption"): if not cls.encryption_parameters: raise Exception( # pylint: disable=broad-exception-raised "Unable to build a URL with encryption enabled" ) query.update(cls.encryption_parameters) return str( URL.create( f"{cls.engine}+{cls.default_driver}".rstrip("+"), username="token", password=parameters.get("access_token"), host=parameters["host"], port=parameters["port"], database=parameters["database"], query=query, ) ) @classmethod def get_parameters_from_uri( # type: ignore cls, uri: str, *_, **__ ) -> DatabricksNativeParametersType: url = make_url_safe(uri) encryption = all( item in url.query.items() for item in cls.encryption_parameters.items() ) return { "access_token": url.password, "host": url.host, "port": url.port, "database": url.database, "encryption": encryption, } @classmethod def parameters_json_schema(cls) -> Any: """ Return configuration parameters as OpenAPI. """ if not cls.properties_schema: return None spec = APISpec( title="Database Parameters", version="1.0.0", openapi_version="3.0.2", plugins=[MarshmallowPlugin()], ) spec.components.schema(cls.__name__, schema=cls.properties_schema) return spec.to_dict()["components"]["schemas"][cls.__name__] @classmethod def get_default_catalog(cls, database: Database) -> str: """ Return the default catalog. It's optionally specified in `connect_args.catalog`. If not: The default behavior for Databricks is confusing. When Unity Catalog is not enabled we have (the DB engine spec hasn't been tested with it enabled): > SHOW CATALOGS; spark_catalog > SELECT current_catalog(); hive_metastore To handle permissions correctly we use the result of `SHOW CATALOGS` when a single catalog is returned. """ connect_args = cls.get_extra_params(database)["engine_params"]["connect_args"] if default_catalog := connect_args.get("catalog"): return default_catalog with database.get_sqla_engine() as engine: catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")} if len(catalogs) == 1: return catalogs.pop() return engine.execute("SELECT current_catalog()").scalar() @classmethod def get_prequeries( cls, database: Database, catalog: str | None = None, schema: str | None = None, ) -> list[str]: prequeries = [] if catalog: catalog = f"`{catalog}`" if not catalog.startswith("`") else catalog prequeries.append(f"USE CATALOG {catalog}") if schema: schema = f"`{schema}`" if not schema.startswith("`") else schema prequeries.append(f"USE SCHEMA {schema}") return prequeries @classmethod def get_catalog_names( cls, database: Database, inspector: Inspector, ) -> set[str]: return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} class DatabricksPythonConnectorEngineSpec(DatabricksDynamicBaseEngineSpec): engine = "databricks" engine_name = "Databricks" default_driver = "databricks-sql-python" drivers = {"databricks-sql-python": "Databricks SQL Python"} parameters_schema = DatabricksPythonConnectorSchema() sqlalchemy_uri_placeholder = ( "databricks://token:{access_token}@{host}:{port}?http_path={http_path}" "&catalog={default_catalog}&schema={default_schema}" ) context_key_mapping = { **DatabricksDynamicBaseEngineSpec.context_key_mapping, "default_catalog": "catalog", "default_schema": "schema", "http_path_field": "http_path", } required_parameters = DatabricksDynamicBaseEngineSpec.required_parameters | { "default_catalog", "default_schema", "http_path_field", } supports_dynamic_schema = supports_catalog = supports_dynamic_catalog = True @classmethod def build_sqlalchemy_uri( # type: ignore cls, parameters: DatabricksPythonConnectorParametersType, *_ ) -> str: query = {} if http_path := parameters.get("http_path_field"): query["http_path"] = http_path if catalog := parameters.get("default_catalog"): query["catalog"] = catalog if schema := parameters.get("default_schema"): query["schema"] = schema if parameters.get("encryption"): query.update(cls.encryption_parameters) return str( URL.create( cls.engine, username="token", password=parameters.get("access_token"), host=parameters["host"], port=parameters["port"], query=query, ) ) @classmethod def get_parameters_from_uri( # type: ignore cls, uri: str, *_: Any, **__: Any ) -> DatabricksPythonConnectorParametersType: url = make_url_safe(uri) query = { key: value for (key, value) in url.query.items() if (key, value) not in cls.encryption_parameters.items() } encryption = all( item in url.query.items() for item in cls.encryption_parameters.items() ) return { "access_token": url.password, "host": url.host, "port": url.port, "http_path_field": query["http_path"], "default_catalog": query["catalog"], "default_schema": query["schema"], "encryption": encryption, } @classmethod def get_default_catalog( cls, database: Database, ) -> str | None: return database.url_object.query.get("catalog") @classmethod def get_catalog_names( cls, database: Database, inspector: Inspector, ) -> set[str]: return {catalog for (catalog,) in inspector.bind.execute("SHOW CATALOGS")} @classmethod def adjust_engine_params( cls, uri: URL, connect_args: dict[str, Any], catalog: str | None = None, schema: str | None = None, ) -> tuple[URL, dict[str, Any]]: if catalog: uri = uri.update_query_dict({"catalog": catalog}) if schema: uri = uri.update_query_dict({"schema": schema}) return uri, connect_args