# odps/superset_odps.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import logging
import sys
try:
from sqlalchemy import Column
except ImportError:
Column = None
try:
from superset import sql_parse
from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
from superset.exceptions import SupersetException
except ImportError:
    # fallback definitions used only in tests
sql_parse = None
class BaseEngineSpec(object):
allows_sql_comments = True
arraysize = 0
@classmethod
def get_engine(cls, database, schema=None, source=None):
return database.get_sqla_engine_with_context(schema=schema, source=source)
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
database,
table,
query,
columns=None,
):
pass
@classmethod
def get_table_names( # pylint: disable=unused-argument
cls,
database,
inspector,
schema,
):
return set(inspector.get_table_names(schema))
@classmethod
def get_dbapi_mapped_exception(cls, ex):
return ex
class SupersetException(Exception):
pass
try:
from superset.constants import TimeGrain
except ImportError:
# compatibility for older superset versions
class TimeGrain:
SECOND = "PT1S"
MINUTE = "PT1M"
HOUR = "PT1H"
DAY = "P1D"
WEEK = "P1W"
MONTH = "P1M"
QUARTER = "P3M"
YEAR = "P1Y"
WEEK_ENDING_SATURDAY = "P1W/1970-01-03T00:00:00Z"
WEEK_STARTING_SUNDAY = "1969-12-28T00:00:00Z/P1W"
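# The values above are ISO 8601 duration strings mirroring the constants in
# superset.constants.TimeGrain, so the time-grain expression table below works
# regardless of which definition is in scope.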
from .compat import getargspec, getfullargspec, six
from .config import options
from .df import DataFrame
from .utils import TEMP_TABLE_PREFIX
logger = logging.getLogger(__name__)
if getfullargspec is None:
getfullargspec = getargspec
_builtin_funcs = set(
"""
ABS ACOS ADD_MONTHS ALL_MATCH ANY_MATCH ANY_VALUE ATAN2 APPROX_DISTINCT
ARG_MAX ARG_MIN ARRAY ARRAY_CONTAINS ARRAY_DISTINCT ARRAY_EXCEPT
ARRAY_INTERSECT ARRAY_JOIN ARRAY_MAX ARRAY_MIN ARRAY_NORMALIZE
ARRAY_POSITION ARRAY_REDUCE ARRAY_REMOVE ARRAY_REPEAT ARRAY_SORT
ARRAY_UNION ARRAYS_OVERLAP ARRAYS_ZIP ASCII ASIN ATAN AVG BASE64
BIN BITWISE_AND_AGG BITWISE_OR_AGG CAST CBRT CEIL CHAR_MATCHCOUNT
CHR CLUSTER_SAMPLE COALESCE COLLECT_LIST COLLECT_SET COMBINATIONS
COMPRESS CONCAT CONCAT_WS CONV CORR COS COSH COT COUNT COUNT_IF
COVAR_POP COVAR_SAMP CRC32 CUME_DIST CURRENT_TIMESTAMP CURRENT_TIMEZONE
DATE_ADD DATE_FORMAT DATE_SUB DATEADD DATEDIFF DATEPART DATETRUNC DAY
DAYOFMONTH DAYOFWEEK DAYOFYEAR DECODE DECOMPRESS DEGREES DENSE_RANK
E ENCODE EXP EXPLODE EXTRACT FACTORIAL FIELD FILTER FIRST_VALUE
FIND_IN_SET FLATTEN FLOOR FORMAT_NUMBER FROM_JSON FROM_UNIXTIME
FROM_UTC_TIMESTAMP GET_IDCARD_AGE GET_IDCARD_BIRTHDAY GET_IDCARD_SEX
GET_JSON_OBJECT GET_USER_ID GETDATE GREATEST HASH HEX HISTOGRAM
HOUR IF INDEX INLINE INITCAP INSTR IS_ENCODING ISDATE ISNAN JSON_OBJECT
JSON_ARRAY JSON_EXTRACT JSON_EXISTS JSON_PRETTY JSON_TYPE JSON_FORMAT
JSON_PARSE JSON_VALID JSON_TUPLE KEYVALUE KEYVALUE_TUPLE LAG
LAST_DAY LASTDAY LAST_VALUE LEAD LEAST LENGTH LENGTHB LN LOCATE
LOG LOG10 LOG2 LPAD LTRIM MAP MAP_AGG MAP_CONCAT MAP_ENTRIES
MAP_FILTER MAP_FROM_ARRAYS MAP_FROM_ENTRIES MAP_KEYS MAP_UNION
MAP_UNION_SUM MAP_VALUES MAP_ZIP_WITH MASK_HASH MAX MAX_BY MAX_PT
MD5 MEDIAN MIN MIN_BY MINUTE MONTH MONTHS_BETWEEN MULTIMAP_AGG
MULTIMAP_FROM_ENTRIES NAMED_STRUCT NEGATIVE NEXT_DAY NGRAMS NOW
NTILE NTH_VALUE NULLIF NUMERIC_HISTOGRAM NVL ORDINAL PARSE_URL
PARSE_URL_TUPLE PARTITION_EXISTS PERCENT_RANK PERCENTILE
PERCENTILE_APPROX PI POSEXPLODE POSITIVE POW QUARTER RADIANS
RAND RANK REGEXP_COUNT REGEXP_EXTRACT REGEXP_EXTRACT_ALL REGEXP_INSTR
REGEXP_REPLACE REGEXP_SUBSTR REPEAT REPLACE REVERSE ROUND ROW_NUMBER
RPAD RTRIM SAMPLE SECOND SEQUENCE SHA SHA1 SHA2 SHIFTLEFT SHIFTRIGHT
SHIFTRIGHTUNSIGNED SHUFFLE SIGN SIN SINH SIZE SLICE SORT_ARRAY
SOUNDEX SPACE SPLIT SPLIT_PART SQRT STACK STDDEV STDDEV_SAMP
STR_TO_MAP STRUCT SUBSTR SUBSTRING SUBSTRING_INDEX SUM
SYM_DECRYPT SYM_ENCRYPT TABLE_EXISTS TAN TANH TO_CHAR TO_DATE
TO_JSON TO_MILLIS TOLOWER TOUPPER TRANS_ARRAY TRANS_COLS
TRANSFORM TRANSFORM_KEYS TRANSFORM_VALUES TRANSLATE TRIM TRUNC
UNBASE64 UNHEX UNIQUE_ID UNIX_TIMESTAMP URL_DECODE URL_ENCODE
    UUID VAR_SAMP VARIANCE VAR_POP WEEKDAY WEEKOFYEAR WIDTH_BUCKET
WM_CONCAT YEAR ZIP_WITH
""".strip().split()
)
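# _builtin_funcs seeds SQL Lab autocompletion: get_function_names() below
# merges it with the user-defined functions listed from the project.
# A minimal sketch of how it is consulted:
#
#     "CONCAT" in _builtin_funcs  # -> True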
class ODPSEngineSpec(BaseEngineSpec):
engine = "odps"
engine_aliases = {"maxcompute"}
engine_name = "ODPS"
# pylint: disable=line-too-long
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "datetrunc({col}, 'ss')",
TimeGrain.MINUTE: "datetrunc({col}, 'mi')",
TimeGrain.HOUR: "datetrunc({col}, 'hh')",
TimeGrain.DAY: "datetrunc({col}, 'dd')",
TimeGrain.WEEK: "datetrunc(dateadd({col}, 1 - dayofweek({col}), 'dd'), 'dd')",
TimeGrain.MONTH: "datetrunc({col}, 'month')",
TimeGrain.QUARTER: "datetrunc(dateadd({col}, -3, 'mm'), 'dd')",
TimeGrain.YEAR: "datetrunc({col}, 'yyyy')",
TimeGrain.WEEK_ENDING_SATURDAY: "datetrunc(dateadd({col}, 6 - dayofweek({col}), 'dd'), 'dd')",
TimeGrain.WEEK_STARTING_SUNDAY: "datetrunc(dateadd({col}, 7 - dayofweek({col}), 'dd'), 'dd')",
}
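    # A rough sketch of how Superset renders one of these templates (the
    # column name "ds" is illustrative):
    #
    #     _time_grain_expressions[TimeGrain.DAY].format(col="ds")
    #     # -> "datetrunc(ds, 'dd')"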
_py_format_to_odps_sql_format = [
("%Y", "YYYY"),
("%m", "MM"),
("%d", "DD"),
("%H", "HH"),
("%M", "MI"),
("%S", "SS"),
("%%", "%"),
]
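    # Applied left to right to the generated SQL, so a Python format string
    # such as "%Y-%m-%d %H:%M:%S" becomes "YYYY-MM-DD HH:MI:SS" (a sketch;
    # see get_timestamp_expr below).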
@classmethod
def get_timestamp_expr(cls, col, pdf, time_grain):
time_expr = (
super(ODPSEngineSpec, cls).get_timestamp_expr(col, pdf, time_grain).key
)
        # translate Python strftime tokens in the generated SQL expression
        # into their ODPS SQL format equivalents
        for pat, sub in cls._py_format_to_odps_sql_format:
            time_expr = time_expr.replace(pat, sub)
return TimestampExpression(time_expr, col, type_=col.type)
@classmethod
@contextlib.contextmanager
def _get_database_engine(cls, database):
en = cls.get_engine(database)
try:
if hasattr(en, "__enter__"):
engine = en.__enter__()
else:
engine = en
yield engine
finally:
if hasattr(en, "__exit__"):
en.__exit__(*sys.exc_info())
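    # get_engine() returns a plain Engine in older Superset versions and a
    # context manager in newer ones; the wrapper above hides the difference.
    # A minimal usage sketch ("database" is whatever Superset passes in):
    #
    #     with cls._get_database_engine(database) as engine:
    #         engine.dialect.get_odps_from_url(engine.url)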
@classmethod
def _get_odps_entry(cls, database):
with cls._get_database_engine(database) as engine:
return engine.dialect.get_odps_from_url(engine.url)
@classmethod
def get_catalog_names( # pylint: disable=unused-argument
cls,
database,
inspector,
):
engine = inspector.engine
odps_entry = engine.dialect.get_odps_from_url(engine.url)
try:
return [proj.name for proj in odps_entry.list_projects()]
        except Exception:
            # the account may lack permission to list projects; fall back
            # to the project of the current connection
            return [odps_entry.project]
@classmethod
def _where_latest_partition( # pylint: disable=too-many-arguments
cls,
table_name,
schema,
database,
query,
columns=None,
):
odps_entry = cls._get_odps_entry(database)
table = odps_entry.get_table(table_name, schema=schema)
if not table.schema.partitions:
return None
max_pt = table.get_max_partition()
if columns is None or not max_pt:
return None
res_cols = set(c.get("name") for c in columns)
for col_name, value in max_pt.partition_spec.items():
if col_name in res_cols:
query = query.where(Column(col_name) == value)
return query
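    # Sketch of the effect, assuming a table partitioned by "pt" whose latest
    # partition is pt='20240101' (names and values illustrative):
    #
    #     SELECT ... FROM t WHERE pt = '20240101'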
if "schema" in getfullargspec(BaseEngineSpec.where_latest_partition).args:
        # Superset prior to v4.1.0 passes the table name and schema as
        # separate arguments (legacy signature)
where_latest_partition = _where_latest_partition
else:
@classmethod
def where_latest_partition(cls, database, table, query, columns=None):
return cls._where_latest_partition(
table.table, table.schema, database, query, columns
)
@classmethod
def select_star(cls, *args, **kw):
if "indent" in kw:
            # disable indentation so the statement is not routed through sqlglot
kw["indent"] = False
return super(ODPSEngineSpec, cls).select_star(*args, **kw)
@classmethod
def latest_sub_partition( # type: ignore
cls, table_name, schema, database, **kwargs
):
# TODO: implement
pass
@classmethod
def get_table_names(cls, database, inspector, schema):
logger.info("Start listing tables for schema %s", schema)
tables = super(ODPSEngineSpec, cls).get_table_names(database, inspector, schema)
        return {n for n in tables if not n.startswith(TEMP_TABLE_PREFIX)}
@classmethod
def get_function_names(cls, database):
with cls._get_database_engine(database) as engine:
cached = engine.dialect.get_list_cache(engine.url, ("functions",))
if cached is not None:
return cached
odps_entry = cls._get_odps_entry(database)
        funcs = {
            func.name
            for func in odps_entry.list_functions()
            if not func.name.startswith("pyodps_")
        }
funcs = sorted(funcs | _builtin_funcs)
with cls._get_database_engine(database) as engine:
engine.dialect.put_list_cache(engine.url, ("functions",), funcs)
return funcs
@classmethod
def execute(cls, cursor, query, database=None, **kwargs):
options.verbose = True
if not cls.allows_sql_comments:
query = sql_parse.strip_comments_from_sql(query)
if cls.arraysize:
cursor.arraysize = cls.arraysize
try:
hints = {
"odps.sql.jobconf.odps2": "true",
}
conn_project_as_schema = getattr(
cursor.connection, "_project_as_schema", None
)
conn_project_as_schema = (
True if conn_project_as_schema is None else conn_project_as_schema
)
if not conn_project_as_schema:
                # the SQLAlchemy cursor needs ODPS schema support enabled
hints.update(
{
"odps.sql.allow.namespace.schema": "true",
"odps.namespace.schema": "true",
}
)
cursor.execute(query, hints=hints)
except Exception as ex:
six.raise_from(cls.get_dbapi_mapped_exception(ex), ex)
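    # With project-as-schema disabled, the hints sent alongside the query end
    # up as:
    #
    #     {"odps.sql.jobconf.odps2": "true",
    #      "odps.sql.allow.namespace.schema": "true",
    #      "odps.namespace.schema": "true"}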
@classmethod
def df_to_sql(cls, database, table, df, to_sql_kwargs):
options.verbose = True
odps_entry = cls._get_odps_entry(database)
if to_sql_kwargs["if_exists"] == "fail":
# Ensure table doesn't already exist.
if odps_entry.exist_table(table.table, schema=table.schema):
raise SupersetException("Table already exists")
elif to_sql_kwargs["if_exists"] == "replace":
odps_entry.delete_table(table.table, schema=table.schema, if_exists=True)
odps_df = DataFrame(df)
odps_df.persist(table.table, overwrite=False, odps=odps_entry)
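# Superset picks this spec up by matching ``engine = "odps"`` against the
# SQLAlchemy dialect name. A sketch of a matching database URI (placeholders
# are illustrative):
#
#     odps://<access_id>:<secret_access_key>@<project>/?endpoint=<endpoint>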