odps/superset_odps.py

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import logging
import sys

try:
    from sqlalchemy import Column
except ImportError:
    Column = None

try:
    from superset import sql_parse
    from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
    from superset.exceptions import SupersetException
except ImportError:
    # import fallback for tests only
    sql_parse = None

    class BaseEngineSpec(object):
        allows_sql_comments = True
        arraysize = 0

        @classmethod
        def get_engine(cls, database, schema=None, source=None):
            return database.get_sqla_engine_with_context(schema=schema, source=source)

        @classmethod
        def where_latest_partition(  # pylint: disable=too-many-arguments
            cls,
            database,
            table,
            query,
            columns=None,
        ):
            pass

        @classmethod
        def get_table_names(  # pylint: disable=unused-argument
            cls,
            database,
            inspector,
            schema,
        ):
            return set(inspector.get_table_names(schema))

        @classmethod
        def get_dbapi_mapped_exception(cls, ex):
            return ex

    class SupersetException(Exception):
        pass


try:
    from superset.constants import TimeGrain
except ImportError:
    # compatibility for older superset versions
    class TimeGrain:
        SECOND = "PT1S"
        MINUTE = "PT1M"
        HOUR = "PT1H"
        DAY = "P1D"
        WEEK = "P1W"
        MONTH = "P1M"
        QUARTER = "P3M"
        YEAR = "P1Y"
        WEEK_ENDING_SATURDAY = "P1W/1970-01-03T00:00:00Z"
        WEEK_STARTING_SUNDAY = "1969-12-28T00:00:00Z/P1W"


from .compat import getargspec, getfullargspec, six
from .config import options
from .df import DataFrame
from .utils import TEMP_TABLE_PREFIX

logger = logging.getLogger(__name__)

if getfullargspec is None:
    getfullargspec = getargspec

# MaxCompute builtin function names, merged with project UDFs in
# get_function_names() so SQL Lab autocomplete also covers builtins
_builtin_funcs = set(
    """
    ABS ACOS ADD_MONTHS ALL_MATCH ANY_MATCH ANY_VALUE ATAN2 APPROX_DISTINCT
    ARG_MAX ARG_MIN ARRAY ARRAY_CONTAINS ARRAY_DISTINCT ARRAY_EXCEPT
    ARRAY_INTERSECT ARRAY_JOIN ARRAY_MAX ARRAY_MIN ARRAY_NORMALIZE
    ARRAY_POSITION ARRAY_REDUCE ARRAY_REMOVE ARRAY_REPEAT ARRAY_SORT
    ARRAY_UNION ARRAYS_OVERLAP ARRAYS_ZIP ASCII ASIN ATAN AVG BASE64 BIN
    BITWISE_AND_AGG BITWISE_OR_AGG CAST CBRT CEIL CHAR_MATCHCOUNT CHR
    CLUSTER_SAMPLE COALESCE COLLECT_LIST COLLECT_SET COMBINATIONS COMPRESS
    CONCAT CONCAT_WS CONV CORR COS COSH COT COUNT COUNT_IF COVAR_POP
    COVAR_SAMP CRC32 CUME_DIST CURRENT_TIMESTAMP CURRENT_TIMEZONE DATE_ADD
    DATE_FORMAT DATE_SUB DATEADD DATEDIFF DATEPART DATETRUNC DAY DAYOFMONTH
    DAYOFWEEK DAYOFYEAR DECODE DECOMPRESS DEGREES DENSE_RANK E ENCODE EXP
    EXPLODE EXTRACT FACTORIAL FIELD FILTER FIRST_VALUE FIND_IN_SET FLATTEN
    FLOOR FORMAT_NUMBER FROM_JSON FROM_UNIXTIME FROM_UTC_TIMESTAMP
    GET_IDCARD_AGE GET_IDCARD_BIRTHDAY GET_IDCARD_SEX GET_JSON_OBJECT
    GET_USER_ID GETDATE GREATEST HASH HEX HISTOGRAM HOUR IF INDEX INLINE
    INITCAP INSTR IS_ENCODING ISDATE ISNAN JSON_OBJECT JSON_ARRAY
    JSON_EXTRACT JSON_EXISTS JSON_PRETTY JSON_TYPE JSON_FORMAT JSON_PARSE
    JSON_VALID JSON_TUPLE KEYVALUE KEYVALUE_TUPLE LAG LAST_DAY LASTDAY
    LAST_VALUE LEAD LEAST LENGTH LENGTHB LN LOCATE LOG LOG10 LOG2 LPAD LTRIM
    MAP MAP_AGG MAP_CONCAT MAP_ENTRIES MAP_FILTER MAP_FROM_ARRAYS
    MAP_FROM_ENTRIES MAP_KEYS MAP_UNION MAP_UNION_SUM MAP_VALUES
    MAP_ZIP_WITH MASK_HASH MAX MAX_BY MAX_PT MD5 MEDIAN MIN MIN_BY MINUTE
    MONTH MONTHS_BETWEEN MULTIMAP_AGG MULTIMAP_FROM_ENTRIES NAMED_STRUCT
    NEGATIVE NEXT_DAY NGRAMS NOW NTILE NTH_VALUE NULLIF NUMERIC_HISTOGRAM
    NVL ORDINAL PARSE_URL PARSE_URL_TUPLE PARTITION_EXISTS PERCENT_RANK
    PERCENTILE PERCENTILE_APPROX PI POSEXPLODE POSITIVE POW QUARTER RADIANS
    RAND RANK REGEXP_COUNT REGEXP_EXTRACT REGEXP_EXTRACT_ALL REGEXP_INSTR
    REGEXP_REPLACE REGEXP_SUBSTR REPEAT REPLACE REVERSE ROUND ROW_NUMBER
    RPAD RTRIM SAMPLE SECOND SEQUENCE SHA SHA1 SHA2 SHIFTLEFT SHIFTRIGHT
    SHIFTRIGHTUNSIGNED SHUFFLE SIGN SIN SINH SIZE SLICE SORT_ARRAY SOUNDEX
    SPACE SPLIT SPLIT_PART SQRT STACK STDDEV STDDEV_SAMP STR_TO_MAP STRUCT
    SUBSTR SUBSTRING SUBSTRING_INDEX SUM SYM_DECRYPT SYM_ENCRYPT
    TABLE_EXISTS TAN TANH TO_CHAR TO_DATE TO_JSON TO_MILLIS TOLOWER TOUPPER
    TRANS_ARRAY TRANS_COLS TRANSFORM TRANSFORM_KEYS TRANSFORM_VALUES
    TRANSLATE TRIM TRUNC UNBASE64 UNHEX UNIQUE_ID UNIX_TIMESTAMP URL_DECODE
    URL_ENCODE UUID VAR_SAMP VARIANCE VAR_POP WEEKDAY WEEKOFYEAR
    WIDTH_BUCKET WM_CONCAT YEAR ZIP_WITH
    """.strip().split()
)


class ODPSEngineSpec(BaseEngineSpec):
    engine = "odps"
    engine_aliases = {"maxcompute"}
    engine_name = "ODPS"

    # pylint: disable=line-too-long
    _time_grain_expressions = {
        None: "{col}",
        TimeGrain.SECOND: "datetrunc({col}, 'ss')",
        TimeGrain.MINUTE: "datetrunc({col}, 'mi')",
        TimeGrain.HOUR: "datetrunc({col}, 'hh')",
        TimeGrain.DAY: "datetrunc({col}, 'dd')",
        TimeGrain.WEEK: "datetrunc(dateadd({col}, 1 - dayofweek({col}), 'dd'), 'dd')",
        TimeGrain.MONTH: "datetrunc({col}, 'month')",
        TimeGrain.QUARTER: "datetrunc(dateadd({col}, -3, 'mm'), 'dd')",
        TimeGrain.YEAR: "datetrunc({col}, 'yyyy')",
        TimeGrain.WEEK_ENDING_SATURDAY: "datetrunc(dateadd({col}, 6 - dayofweek({col}), 'dd'), 'dd')",
        TimeGrain.WEEK_STARTING_SUNDAY: "datetrunc(dateadd({col}, 7 - dayofweek({col}), 'dd'), 'dd')",
    }

    _py_format_to_odps_sql_format = [
        ("%Y", "YYYY"),
        ("%m", "MM"),
        ("%d", "DD"),
        ("%H", "HH"),
        ("%M", "MI"),
        ("%S", "SS"),
        ("%%", "%"),
    ]

    @classmethod
    def get_timestamp_expr(cls, col, pdf, time_grain):
        time_expr = (
            super(ODPSEngineSpec, cls).get_timestamp_expr(col, pdf, time_grain).key
        )
        # rewrite Python strftime patterns into MaxCompute datetime patterns
        # (the original assigned the result back to ``pdf`` and discarded it)
        for pat, sub in cls._py_format_to_odps_sql_format:
            time_expr = time_expr.replace(pat, sub)
        return TimestampExpression(time_expr, col, type_=col.type)
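
    # Illustrative only (not executed here): with TimeGrain.WEEK, the template
    # above renders a column ``dt`` as
    #
    #     datetrunc(dateadd(dt, 1 - dayofweek(dt), 'dd'), 'dd')
    #
    # and the format mapping rewrites Python strftime patterns, e.g.:
    #
    #     >>> fmt = "%Y-%m-%d %H:%M:%S"
    #     >>> for pat, sub in ODPSEngineSpec._py_format_to_odps_sql_format:
    #     ...     fmt = fmt.replace(pat, sub)
    #     >>> fmt
    #     'YYYY-MM-DD HH:MI:SS'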

    @classmethod
    @contextlib.contextmanager
    def _get_database_engine(cls, database):
        # get_engine() may return either a context manager (newer superset)
        # or a plain engine; normalize both into a context manager
        en = cls.get_engine(database)
        try:
            if hasattr(en, "__enter__"):
                engine = en.__enter__()
            else:
                engine = en
            yield engine
        finally:
            if hasattr(en, "__exit__"):
                en.__exit__(*sys.exc_info())

    @classmethod
    def _get_odps_entry(cls, database):
        with cls._get_database_engine(database) as engine:
            return engine.dialect.get_odps_from_url(engine.url)

    @classmethod
    def get_catalog_names(  # pylint: disable=unused-argument
        cls,
        database,
        inspector,
    ):
        engine = inspector.engine
        odps_entry = engine.dialect.get_odps_from_url(engine.url)
        try:
            return [proj.name for proj in odps_entry.list_projects()]
        except:
            # listing projects may be disallowed; fall back to the current one
            return [odps_entry.project]

    @classmethod
    def _where_latest_partition(  # pylint: disable=too-many-arguments
        cls,
        table_name,
        schema,
        database,
        query,
        columns=None,
    ):
        odps_entry = cls._get_odps_entry(database)
        table = odps_entry.get_table(table_name, schema=schema)
        if not table.schema.partitions:
            return None
        max_pt = table.get_max_partition()
        if columns is None or not max_pt:
            return None
        res_cols = set(c.get("name") for c in columns)
        for col_name, value in max_pt.partition_spec.items():
            if col_name in res_cols:
                query = query.where(Column(col_name) == value)
        return query
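
    # A sketch of the effect (table layout and partition values below are
    # hypothetical): for a table partitioned by (pt, hh) whose newest
    # partition is pt='20240101',hh='00', _where_latest_partition() extends
    # the incoming SELECT with
    #
    #     WHERE pt = '20240101' AND hh = '00'
    #
    # so "latest partition" queries scan a single partition; non-partitioned
    # tables (or a missing max partition) yield None and the query is used
    # unchanged.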

    if "schema" in getfullargspec(BaseEngineSpec.where_latest_partition).args:
        # superset prior to v4.1.0 uses the legacy call signature
        where_latest_partition = _where_latest_partition
    else:

        @classmethod
        def where_latest_partition(cls, database, table, query, columns=None):
            return cls._where_latest_partition(
                table.table, table.schema, database, query, columns
            )

    @classmethod
    def select_star(cls, *args, **kw):
        if "indent" in kw:
            # suspend indentation to circumvent calls to sqlglot
            kw["indent"] = False
        return super(ODPSEngineSpec, cls).select_star(*args, **kw)

    @classmethod
    def latest_sub_partition(  # type: ignore
        cls, table_name, schema, database, **kwargs
    ):
        # TODO: implement
        pass

    @classmethod
    def get_table_names(cls, database, inspector, schema):
        logger.info("Start listing tables for schema %s", schema)
        tables = super(ODPSEngineSpec, cls).get_table_names(database, inspector, schema)
        # hide temporary tables created by PyODPS itself
        return {n for n in tables if not n.startswith(TEMP_TABLE_PREFIX)}

    @classmethod
    def get_function_names(cls, database):
        with cls._get_database_engine(database) as engine:
            cached = engine.dialect.get_list_cache(engine.url, ("functions",))
            if cached is not None:
                return cached

        odps_entry = cls._get_odps_entry(database)
        funcs = {
            func.name
            for func in odps_entry.list_functions()
            if not func.name.startswith("pyodps_")
        }
        funcs = sorted(funcs | _builtin_funcs)

        with cls._get_database_engine(database) as engine:
            engine.dialect.put_list_cache(engine.url, ("functions",), funcs)
        return funcs

    @classmethod
    def execute(cls, cursor, query, database=None, **kwargs):
        options.verbose = True

        if not cls.allows_sql_comments:
            query = sql_parse.strip_comments_from_sql(query)

        if cls.arraysize:
            cursor.arraysize = cls.arraysize
        try:
            hints = {
                "odps.sql.jobconf.odps2": "true",
            }
            conn_project_as_schema = getattr(
                cursor.connection, "_project_as_schema", None
            )
            conn_project_as_schema = (
                True if conn_project_as_schema is None else conn_project_as_schema
            )
            if not conn_project_as_schema:
                # sqlalchemy cursor needs odps schema support
                hints.update(
                    {
                        "odps.sql.allow.namespace.schema": "true",
                        "odps.namespace.schema": "true",
                    }
                )
            cursor.execute(query, hints=hints)
        except Exception as ex:
            six.raise_from(cls.get_dbapi_mapped_exception(ex), ex)

    @classmethod
    def df_to_sql(cls, database, table, df, to_sql_kwargs):
        options.verbose = True

        odps_entry = cls._get_odps_entry(database)

        if to_sql_kwargs["if_exists"] == "fail":
            # Ensure table doesn't already exist.
            if odps_entry.exist_table(table.table, schema=table.schema):
                raise SupersetException("Table already exists")
        elif to_sql_kwargs["if_exists"] == "replace":
            odps_entry.delete_table(table.table, schema=table.schema, if_exists=True)

        odps_df = DataFrame(df)
        odps_df.persist(table.table, overwrite=False, odps=odps_entry)
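
# A minimal registration sketch (assumption: the URI below and the entry-point
# wiring are illustrative, not defined in this module). Superset loads
# third-party engine specs from the "superset.db_engine_specs" entry-point
# group, so once this package is installed alongside Superset, a MaxCompute
# database can be added with a SQLAlchemy URI shaped like
#
#     odps://<access_id>:<secret_key>@<project>
#
# after which the time grains, builtin function list and latest-partition
# helpers above take effect in SQL Lab and chart queries.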