# odps/sqlalchemy_odps.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import itertools
import sys
import threading
import time
from sqlalchemy import types as sa_types
from sqlalchemy.engine import Engine, default
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import compiler, sqltypes
try:
from sqlalchemy.dialects import mysql
except ImportError:
# for low sqlalchemy versions
from sqlalchemy.databases import mysql
from . import options, types
from .compat import six
from .core import DEFAULT_ENDPOINT, ODPS
from .errors import BaseODPSError, InternalServerError, NoSuchObject
from .models import Table
from .models.session.v1 import PUBLIC_SESSION_NAME
from .utils import to_str, to_text
# Thread-local holder for test-only hooks. ``get_tables_filter`` may be set to
# a callable (table name -> bool) that restricts which tables reflection
# returns; see ``_iter_tables`` below.
# NOTE(review): the attribute is only initialized on the importing thread —
# readers on other threads must use ``getattr(..., default)`` as done below.
test_setting = threading.local()
test_setting.get_tables_filter = None
@contextlib.contextmanager
def update_test_setting(**kw):
    """Temporarily override attributes on the thread-local ``test_setting``.

    Each keyword argument replaces the attribute of the same name for the
    duration of the ``with`` block; previous values are restored afterwards.

    :raises AttributeError: if an overridden attribute was never set on the
        current thread (values are read before being replaced).
    """
    old_values = {}
    for k in kw:
        old_values[k] = getattr(test_setting, k)
    for k, v in six.iteritems(kw):
        setattr(test_setting, k, v)
    try:
        yield
    finally:
        # restore previous values even when the body raises, so one failing
        # test cannot leak its overrides into subsequent tests
        for k, v in six.iteritems(old_values):
            setattr(test_setting, k, v)
# Mapping from ODPS column types to the closest SQLAlchemy generic types.
# Complex types (Array/Map/Struct/Json) have no SQLAlchemy counterpart and
# are surfaced as plain strings; Tinyint uses the MySQL-specific TINYINT.
_odps_type_to_sqlalchemy_type = {
    types.Boolean: sa_types.Boolean,
    types.Tinyint: mysql.TINYINT,
    types.Smallint: sa_types.SmallInteger,
    types.Int: sa_types.Integer,
    types.Bigint: sa_types.BigInteger,
    types.Float: sa_types.Float,
    types.Double: sa_types.Float,
    types.String: sa_types.String,
    types.Varchar: sa_types.String,
    types.Char: sa_types.String,
    types.Date: sa_types.Date,
    types.Datetime: sa_types.DateTime,
    types.Timestamp: sa_types.TIMESTAMP,
    types.Binary: sa_types.String,
    types.Array: sa_types.String,
    types.Map: sa_types.String,
    types.Struct: sa_types.String,
    types.Decimal: sa_types.DECIMAL,
    types.Json: sa_types.String,
    types.TimestampNTZ: sa_types.TIMESTAMP,
}
# ODPS entries reused across engines created from the same connection URL
# (populated when ``reuse_odps=true`` appears in the URL query).
_sqlalchemy_global_reusable_odps = {}
# Per-URL ``ObjectCache`` of listed object names (populated when
# ``cache_names=true`` appears in the URL query).
_sqlalchemy_obj_list_cache = {}
class ObjectCache(object):
    """A minimal dict-like cache whose entries expire after a fixed TTL.

    Expired entries are evicted lazily, on the access that discovers them.
    """

    def __init__(self, expire=24 * 3600):
        # TTL in seconds; entries older than this count as missing
        self._expire_time = expire
        self._items = {}
        self._cache_time = {}

    def __getitem__(self, key):
        # raises KeyError here when the key was never cached
        stored_at = self._cache_time[key]
        if time.time() - stored_at > self._expire_time:
            # entry expired: evict it and behave as if it were absent
            self._cache_time.pop(key, None)
            self._items.pop(key, None)
            raise KeyError(key)
        return self._items[key]

    def __setitem__(self, key, value):
        self._cache_time[key] = time.time()
        self._items[key] = value

    def get(self, key, default=None):
        """Return the cached value for ``key``, or ``default`` when missing/expired."""
        try:
            return self[key]
        except KeyError:
            return default
class ODPSIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer that backtick-quotes ODPS reserved words."""

    # Just quote everything to make things simpler / easier to upgrade
    reserved_words = compiler.RESERVED_WORDS.copy()
    # ODPS/MaxCompute SQL keywords that must be quoted when used as identifiers
    keywords = [
        "ADD",
        "ALL",
        "ALTER",
        "AND",
        "AS",
        "ASC",
        "BETWEEN",
        "BIGINT",
        "BOOLEAN",
        "BY",
        "CASE",
        "CAST",
        "COLUMN",
        "COMMENT",
        "CREATE",
        "DESC",
        "DISTINCT",
        "DISTRIBUTE",
        "DOUBLE",
        "DROP",
        "ELSE",
        "FALSE",
        "FROM",
        "FULL",
        "GROUP",
        "IF",
        "IN",
        "INSERT",
        "INTO",
        "IS",
        "JOIN",
        "LEFT",
        "LIFECYCLE",
        "LIKE",
        "LIMIT",
        "MAPJOIN",
        "NOT",
        "NULL",
        "ON",
        "OR",
        "ORDER",
        "OUTER",
        "OVERWRITE",
        "PARTITION",
        "RENAME",
        "REPLACE",
        "RIGHT",
        "RLIKE",
        "SELECT",
        "SORT",
        "STRING",
        "TABLE",
        "TABLESAMPLE",
        "TBLPROPERTIES",
        "THEN",
        "TOUCH",
        "TRUE",
        "UNION",
        "VIEW",
        "WHEN",
        "WHERE",
    ]
    # register both upper- and lower-case spellings as reserved
    reserved_words.update(keywords)
    reserved_words.update([s.lower() for s in keywords])

    def __init__(self, dialect):
        # ODPS quotes identifiers with backticks (MySQL/Hive style)
        super(ODPSIdentifierPreparer, self).__init__(
            dialect,
            initial_quote="`",
            escape_quote="`",
        )

    def quote(self, ident, force=None):
        # ensure a native ``str`` is returned on both py2 and py3
        return to_str(super(ODPSIdentifierPreparer, self).quote(ident, force=force))
class ODPSCompiler(compiler.SQLCompiler):
    """SQL statement compiler for the ODPS dialect."""

    def visit_column(self, *args, **kwargs):
        """Render a column reference, stripping any leading schema qualifier.

        ODPS does not accept ``schema.table.column`` references, so a
        fully-qualified name is reduced to ``table.column``.
        """
        rendered = super(ODPSCompiler, self).visit_column(*args, **kwargs)
        n_dots = rendered.count(".")
        assert n_dots in (0, 1, 2), "Unexpected visit_column result {}".format(
            rendered
        )
        if n_dots == 2:
            # we have something of the form schema.table.column
            # hive doesn't like the schema in front, so chop it out
            _, _, rendered = rendered.partition(".")
        return rendered

    def visit_char_length_func(self, fn, **kw):
        # ODPS spells CHAR_LENGTH as ``length``
        return "length%s" % self.function_argspec(fn, **kw)

    def __unicode__(self):
        # py2 compatibility: render the compiled statement as text
        return to_text(self)
class ODPSTypeCompiler(compiler.GenericTypeCompiler):
    """Type compiler mapping SQLAlchemy generic types to ODPS DDL type names."""

    def visit_INTEGER(self, type_):
        # ODPS spells the 32-bit integer type INT
        return "INT"

    def visit_NUMERIC(self, type_):
        return "DECIMAL"

    def _visit_as_string(self, type_):
        # ODPS uses a single STRING type for all character data
        return "STRING"

    visit_CHAR = _visit_as_string
    visit_VARCHAR = _visit_as_string
    visit_NCHAR = _visit_as_string
    visit_TEXT = _visit_as_string
    visit_CLOB = _visit_as_string

    def visit_BLOB(self, type_):
        return "BINARY"

    def visit_TIME(self, type_):
        # ODPS has no TIME type; fall back to TIMESTAMP
        return "TIMESTAMP"
# sqlalchemy >= 1.4 exposes an explicit marker object for "strings come back
# as unicode"; older versions expect a plain boolean instead.
_return_unicode_str = getattr(sqltypes.String, "RETURNS_UNICODE", True)
class ODPSPingError(BaseODPSError):
    """Error re-raised from ``ODPSDialect.do_ping`` so callers (e.g. Superset)
    can distinguish ping failures from other ODPS errors."""

    pass
class ODPSDialect(default.DefaultDialect):
    """SQLAlchemy dialect for MaxCompute (ODPS).

    Connection URLs look like
    ``odps://<access_id>:<secret>@<project>/?endpoint=...`` with optional
    query arguments such as ``interactive_mode``, ``reuse_odps``,
    ``project_as_schema``, ``quota_name``, ``cache_names`` and
    ``cache_seconds``; any unrecognized query arguments are forwarded to the
    server as SQL hints.
    """

    name = "odps"
    driver = "rest"
    preparer = ODPSIdentifierPreparer
    statement_compiler = ODPSCompiler
    supports_views = True
    supports_alter = True
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_native_decimal = True
    supports_native_boolean = True
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = _return_unicode_str
    description_encoding = None
    supports_multivalues_insert = True
    type_compiler = ODPSTypeCompiler
    supports_sane_rowcount = False
    supports_statement_cache = False
    _reused_odps = None

    default_schema_name = "default"

    @classmethod
    def dbapi(cls):
        """Return the DB-API module implementing connections for this dialect."""
        from . import dbapi

        return dbapi

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into ``([], kwargs)`` for ``dbapi.connect``.

        :param url: SQLAlchemy URL object carrying credentials, project and
            query options.
        :raises ValueError: when a required connection argument (project,
            endpoint, credentials) cannot be resolved from the URL or the
            global options.
        """
        url_string = str(url)
        project = url.host
        if project is None and options.default_project:
            project = options.default_project
        access_id = url.username
        secret_access_key = url.password
        logview_host = options.logview_host
        endpoint = None
        session_name = None
        sqa_type = False
        quota_name = None
        reuse_odps = False
        project_as_schema = False
        fallback_policy = ""
        cache_names = False
        cache_seconds = 24 * 3600
        hints = {}
        if url.query:
            query = dict(url.query)
            if endpoint is None:
                endpoint = query.pop("endpoint", None)
            if logview_host is None:
                logview_host = query.pop("logview_host", query.pop("logview", None))
            if session_name is None:
                session_name = query.pop("session", None)
            if quota_name is None:
                quota_name = query.pop("quota_name", None)
            if sqa_type is False:
                # interactive_mode=true selects the default ("v1") session
                # implementation; any other non-"false" value names one
                sqa_type = query.pop("interactive_mode", "false").lower()
                if sqa_type == "true":
                    sqa_type = "v1"
                elif sqa_type == "false":
                    sqa_type = False
            if reuse_odps is False:
                reuse_odps = query.pop("reuse_odps", "false").lower() != "false"
            if query.get("project_as_schema", None) is not None:
                project_as_schema = (
                    query.pop("project_as_schema", "false").lower() != "false"
                )
            if fallback_policy == "":
                fallback_policy = query.pop("fallback_policy", "default")
            if cache_names is False:
                cache_names = query.pop("cache_names", "false").lower() != "false"
            cache_seconds = int(query.pop("cache_seconds", cache_seconds))
            # every query argument not consumed above becomes a SQL hint
            hints = query
        if endpoint is None:
            endpoint = options.endpoint or DEFAULT_ENDPOINT
        if session_name is None:
            session_name = PUBLIC_SESSION_NAME
        kwargs = {
            "access_id": access_id,
            "secret_access_key": secret_access_key,
            "project": project,
            "endpoint": endpoint,
            "session_name": session_name,
            "use_sqa": sqa_type,
            "fallback_policy": fallback_policy,
            "project_as_schema": project_as_schema,
            "hints": hints,
        }
        if quota_name is not None:
            kwargs["quota_name"] = quota_name
        if access_id is None:
            # no credentials in the URL: fall back to the globally
            # configured account object
            kwargs.pop("access_id", None)
            kwargs.pop("secret_access_key", None)
            kwargs["account"] = options.account
        for k, v in six.iteritems(kwargs):
            if v is None:
                raise ValueError(
                    "{} should be provided to create connection, "
                    "you can either specify in connection string as format: "
                    '"odps://<access_id>:<access_key>@<project_name>", '
                    "or create an ODPS object and call `.to_global()` "
                    "to set it to global".format(k)
                )
        if logview_host is not None:
            kwargs["logview_host"] = logview_host
        if cache_names:
            _sqlalchemy_obj_list_cache[url_string] = ObjectCache(expire=cache_seconds)
        if reuse_odps:
            # the odps object can only be reused only if it will be identical
            if (
                url_string in _sqlalchemy_global_reusable_odps
                and _sqlalchemy_global_reusable_odps.get(url_string) is not None
            ):
                kwargs["odps"] = _sqlalchemy_global_reusable_odps.get(url_string)
                kwargs["access_id"] = None
                kwargs["secret_access_key"] = None
            else:
                _sqlalchemy_global_reusable_odps[url_string] = ODPS(
                    access_id=access_id,
                    secret_access_key=secret_access_key,
                    project=project,
                    endpoint=endpoint,
                    logview_host=logview_host,
                )
        return [], kwargs

    def get_odps_from_url(self, url):
        """Build (or reuse) an ``ODPS`` entry object from a SQLAlchemy URL."""
        _, kwargs = self.create_connect_args(url)
        if "odps" in kwargs:
            return kwargs["odps"]
        odps_kw = kwargs.copy()
        # strip connection-only arguments the ODPS constructor does not accept
        odps_kw.pop("session_name", None)
        odps_kw.pop("use_sqa", None)
        odps_kw.pop("fallback_policy", None)
        odps_kw.pop("hints", None)
        odps_kw.pop("project_as_schema", None)
        odps_kw["overwrite_global"] = False
        return ODPS(**odps_kw)

    @classmethod
    def get_list_cache(cls, url, key):
        """Return the cached listing for ``key``, or None when caching is off
        for this URL or the entry is missing/expired."""
        url = str(url)
        if url not in _sqlalchemy_obj_list_cache:
            return None
        return _sqlalchemy_obj_list_cache[url].get(key)

    @classmethod
    def put_list_cache(cls, url, key, value):
        """Store a listing under ``key``; no-op when caching is off for this URL."""
        url = str(url)
        if url not in _sqlalchemy_obj_list_cache:
            return
        _sqlalchemy_obj_list_cache[url][key] = value

    def get_schema_names(self, connection, **kw):
        """List schema names.

        When the connection maps projects to schemas
        (``project_as_schema=true``), project names are returned; otherwise
        ODPS schemas are listed, falling back to ``["default"]`` when schema
        listing is unavailable.
        """
        conn = self._get_dbapi_connection(connection)
        if getattr(conn, "_project_as_schema", False):
            fields = ["owner", "user", "group", "prefix"]
            if (conn.odps.project is None) or (kw.pop("listall", None) is not None):
                kwargs = {f: kw.get(f) for f in fields}
                return [proj.name for proj in conn.odps.list_projects(**kwargs)]
            else:
                return [conn.odps.project]
        else:
            try:
                return [schema.name for schema in conn.odps.list_schemas()]
            except Exception:
                # best-effort: schema listing may be unsupported on the
                # server; keep the fallback but avoid a bare except that
                # would also swallow KeyboardInterrupt/SystemExit
                return ["default"]

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return True when ``table_name`` exists in ``schema``."""
        conn = self._get_dbapi_connection(connection)
        schema_kw = self._get_schema_kw(connection, schema=schema)
        return conn.odps.exist_table(table_name, **schema_kw)

    @classmethod
    def _get_dbapi_connection(cls, sa_connection):
        """Unwrap a SQLAlchemy connection/engine down to the raw DB-API connection."""
        if isinstance(sa_connection, Engine):
            sa_connection = sa_connection.connect()
        return sa_connection.connection.connection

    @classmethod
    def _get_schema_kw(cls, connection, schema=None):
        """Map the SQLAlchemy ``schema`` argument onto ODPS ``project`` or
        ``schema`` depending on the connection's ``project_as_schema`` mode."""
        db_conn = cls._get_dbapi_connection(connection)
        if getattr(db_conn, "_project_as_schema", False):
            return dict(project=schema)
        else:
            return dict(schema=schema)

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Return SQLAlchemy column descriptions for ``table_name``.

        :raises NoSuchTableError: when the table does not exist (the ODPS
            table object is lazy, so the error surfaces on first schema
            access inside the loop).
        """
        conn = self._get_dbapi_connection(connection)
        schema_kw = self._get_schema_kw(connection, schema=schema)
        table = conn.odps.get_table(table_name, **schema_kw)
        result = []
        try:
            for col in table.table_schema.columns:
                col_type = _odps_type_to_sqlalchemy_type[type(col.type)]
                result.append(
                    {
                        "name": col.name,
                        "type": col_type,
                        "nullable": True,
                        "default": None,
                        "comment": col.comment,
                    }
                )
        except NoSuchObject as e:
            # convert ODPSError to SQLAlchemy NoSuchTableError
            raise NoSuchTableError(str(e))
        return result

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        # ODPS has no support for foreign keys.
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        # ODPS has no support for primary keys.
        # NOTE(review): SQLAlchemy's Inspector documents a dict return value
        # here; [] is kept for backward compatibility — confirm downstream
        # usage before changing.
        return []

    def get_indexes(self, connection, table_name, schema=None, **kw):
        # ODPS has no support for indexes
        return []

    def _iter_tables(self, connection, schema=None, types=None, **kw):
        """List table names, optionally restricted to the given table types,
        honoring the per-URL name cache and the test-only table filter."""
        # ``types`` may legitimately be None; normalize for the cache key
        cache_key = ("tables", schema, tuple(types) if types else None)
        cached = self.get_list_cache(connection.engine.url, cache_key)
        if cached is not None:
            return cached
        conn = self._get_dbapi_connection(connection)
        filter_ = getattr(test_setting, "get_tables_filter", None)
        if filter_ is None:
            filter_ = lambda x: True
        schema_kw = self._get_schema_kw(connection, schema=schema)
        if not types:
            it = conn.odps.list_tables(**schema_kw)
        else:
            # one listing call per requested table type, chained together
            its = []
            for table_type in types:
                list_kw = schema_kw.copy()
                list_kw["type"] = table_type
                its.append(conn.odps.list_tables(**list_kw))
            it = itertools.chain(*its)
        result = [t.name for t in it if filter_(t.name)]
        self.put_list_cache(connection.engine.url, cache_key, result)
        return result

    def get_table_names(self, connection, schema=None, **kw):
        """Return names of managed and external tables."""
        return self._iter_tables(
            connection,
            schema=schema,
            types=[Table.Type.MANAGED_TABLE, Table.Type.EXTERNAL_TABLE],
            **kw
        )

    def get_view_names(self, connection, schema=None, **kw):
        """Return names of virtual and materialized views."""
        return self._iter_tables(
            connection,
            schema=schema,
            types=[Table.Type.VIRTUAL_VIEW, Table.Type.MATERIALIZED_VIEW],
            **kw
        )

    def get_table_comment(self, connection, table_name, schema=None, **kw):
        """Return the table comment in SQLAlchemy's ``{"text": ...}`` form."""
        conn = self._get_dbapi_connection(connection)
        schema_kw = self._get_schema_kw(connection, schema=schema)
        comment = conn.odps.get_table(table_name, **schema_kw).comment
        return {"text": comment}

    @classmethod
    def _is_stack_superset(cls, tb):
        """Return True when any frame in the traceback originates from a
        Superset source file (used to scope the do_ping error conversion)."""
        try:
            cur_frame = tb.tb_frame
            while cur_frame is not None:
                if "superset" in cur_frame.f_code.co_filename:
                    return True
                cur_frame = cur_frame.f_back
            return False
        except Exception:  # pragma: no cover
            return False

    def do_ping(self, dbapi_connection):
        """Stop raising RuntimeError when ping by Superset"""
        try:
            return super(ODPSDialect, self).do_ping(dbapi_connection)
        except InternalServerError:
            raise
        except BaseException as ex:
            _, _, tb = sys.exc_info()
            if not self._is_stack_superset(tb):
                raise
            # some exceptions carry no args; fall back to their string form
            new_err = ODPSPingError(ex.args[0] if ex.args else str(ex))
            for attr in (
                "request_id",
                "instance_id",
                "code",
                "host_id",
                "endpoint",
                "tag",
            ):
                # non-ODPS exceptions may lack these metadata attributes
                setattr(new_err, attr, getattr(ex, attr, None))
            six.reraise(ODPSPingError, new_err, tb)

    def do_rollback(self, dbapi_connection):
        # No transactions for ODPS
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        # We decode everything as UTF-8
        return True

    def _check_unicode_description(self, connection):
        # We decode everything as UTF-8
        return True