odps/sqlalchemy_odps.py (470 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import itertools
import sys
import threading
import time

from sqlalchemy import types as sa_types
from sqlalchemy.engine import Engine, default
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.sql import compiler, sqltypes

try:
    from sqlalchemy.dialects import mysql
except ImportError:
    # for low sqlalchemy versions
    from sqlalchemy.databases import mysql

from . import options, types
from .compat import six
from .core import DEFAULT_ENDPOINT, ODPS
from .errors import BaseODPSError, InternalServerError, NoSuchObject
from .models import Table
from .models.session.v1 import PUBLIC_SESSION_NAME
from .utils import to_str, to_text

# Thread-local holder for test-only overrides. Attributes assigned here at
# import time only exist on the importing thread, so readers must use
# ``getattr(test_setting, name, None)``.
test_setting = threading.local()
test_setting.get_tables_filter = None


@contextlib.contextmanager
def update_test_setting(**kw):
    """Temporarily override attributes of ``test_setting``.

    Every keyword argument is assigned onto the thread-local ``test_setting``
    object for the duration of the ``with`` block. The previous values are
    restored on exit, including when the block raises.

    A setting that has no value on the current thread is treated as ``None``
    (the same default ``ODPSDialect._iter_tables`` uses when reading it), so
    this helper also works on threads other than the one that imported the
    module.
    """
    old_values = {k: getattr(test_setting, k, None) for k in kw}

    for k, v in six.iteritems(kw):
        setattr(test_setting, k, v)

    try:
        yield
    finally:
        # set back value, even if the managed block failed
        for k, v in six.iteritems(old_values):
            setattr(test_setting, k, v)


# Maps the Python class of an ODPS column type to the SQLAlchemy type
# reported by ``ODPSDialect.get_columns``.
_odps_type_to_sqlalchemy_type = {
    types.Boolean: sa_types.Boolean,
    types.Tinyint: mysql.TINYINT,
    types.Smallint: sa_types.SmallInteger,
    types.Int: sa_types.Integer,
    types.Bigint: sa_types.BigInteger,
    types.Float: sa_types.Float,
    types.Double: sa_types.Float,
    types.String: sa_types.String,
    types.Varchar: sa_types.String,
    types.Char: sa_types.String,
    types.Date: sa_types.Date,
    types.Datetime: sa_types.DateTime,
    types.Timestamp: sa_types.TIMESTAMP,
    types.Binary: sa_types.String,
    types.Array: sa_types.String,
    types.Map: sa_types.String,
    types.Struct: sa_types.String,
    types.Decimal: sa_types.DECIMAL,
    types.Json: sa_types.String,
    types.TimestampNTZ: sa_types.TIMESTAMP,
}

# ODPS entry objects kept for reuse, keyed by the connection URL string
# (filled by ``create_connect_args`` when ``reuse_odps`` is enabled).
_sqlalchemy_global_reusable_odps = {}
# Per-URL ``ObjectCache`` instances holding listing results
# (created by ``create_connect_args`` when ``cache_names`` is enabled).
_sqlalchemy_obj_list_cache = {}


class ObjectCache(object):
    """Small in-memory mapping whose entries expire after a fixed delay.

    :param expire: lifetime of an entry in seconds, measured from the moment
        it was stored with ``cache[key] = value``.
    """

    def __init__(self, expire=24 * 3600):
        self._expire_time = expire
        self._items = dict()
        self._cache_time = dict()

    def __getitem__(self, key):
        """Return the value stored under ``key``.

        Raises ``KeyError`` when the key was never stored or when its entry
        is older than the expiry delay; an expired entry is dropped first.
        """
        if self._cache_time[key] < time.time() - self._expire_time:
            self._cache_time.pop(key, None)
            self._items.pop(key, None)
            raise KeyError(key)
        return self._items[key]

    def __setitem__(self, key, value):
        """Store ``value`` under ``key`` and stamp it with the current time."""
        self._items[key] = value
        self._cache_time[key] = time.time()

    def get(self, key, default=None):
        """Return the live value for ``key``, or ``default`` if missing or expired."""
        try:
            return self[key]
        except KeyError:
            return default


class ODPSIdentifierPreparer(compiler.IdentifierPreparer):
    """Identifier preparer that quotes with backticks.

    ``reserved_words`` extends SQLAlchemy's default set with the keywords
    listed below, in both upper and lower case.
    """

    # Just quote everything to make things simpler / easier to upgrade
    reserved_words = compiler.RESERVED_WORDS.copy()

    keywords = [
        "ADD",
        "ALL",
        "ALTER",
        "AND",
        "AS",
        "ASC",
        "BETWEEN",
        "BIGINT",
        "BOOLEAN",
        "BY",
        "CASE",
        "CAST",
        "COLUMN",
        "COMMENT",
        "CREATE",
        "DESC",
        "DISTINCT",
        "DISTRIBUTE",
        "DOUBLE",
        "DROP",
        "ELSE",
        "FALSE",
        "FROM",
        "FULL",
        "GROUP",
        "IF",
        "IN",
        "INSERT",
        "INTO",
        "IS",
        "JOIN",
        "LEFT",
        "LIFECYCLE",
        "LIKE",
        "LIMIT",
        "MAPJOIN",
        "NOT",
        "NULL",
        "ON",
        "OR",
        "ORDER",
        "OUTER",
        "OVERWRITE",
        "PARTITION",
        "RENAME",
        "REPLACE",
        "RIGHT",
        "RLIKE",
        "SELECT",
        "SORT",
        "STRING",
        "TABLE",
        "TABLESAMPLE",
        "TBLPROPERTIES",
        "THEN",
        "TOUCH",
        "TRUE",
        "UNION",
        "VIEW",
        "WHEN",
        "WHERE",
    ]
    reserved_words.update(keywords)
    reserved_words.update([s.lower() for s in keywords])

    def __init__(self, dialect):
        super(ODPSIdentifierPreparer, self).__init__(
            dialect,
            initial_quote="`",
            escape_quote="`",
        )

    def quote(self, ident, force=None):
        """Quote ``ident`` via the base class and pass the result through ``to_str``."""
        return to_str(super(ODPSIdentifierPreparer, self).quote(ident, force=force))


class ODPSCompiler(compiler.SQLCompiler):
    """Statement compiler with ODPS-specific column and function rendering."""

    def visit_column(self, *args, **kwargs):
        """Render a column reference, dropping a leading schema qualifier.

        The base rendering is expected to contain at most two dots; when it
        has exactly two, everything up to and including the first dot is
        removed so that only ``table.column`` remains.
        """
        result = super(ODPSCompiler, self).visit_column(*args, **kwargs)
        dot_count = result.count(".")
        assert dot_count in (0, 1, 2), "Unexpected visit_column result {}".format(
            result
        )
        if dot_count == 2:
            # we have something of the form schema.table.column
            # hive doesn't like the schema in front, so chop it out
            result = result[result.index(".") + 1 :]
        return result

    def visit_char_length_func(self, fn, **kw):
        """Render ``char_length(...)`` as ``length(...)``."""
        return "length{}".format(self.function_argspec(fn, **kw))

    def __unicode__(self):
        return to_text(self)


class ODPSTypeCompiler(compiler.GenericTypeCompiler):
    """Type compiler that renders generic SQL types under ODPS type names."""

    def visit_INTEGER(self, type_):
        return "INT"

    def visit_NUMERIC(self, type_):
        return "DECIMAL"

    def visit_CHAR(self, type_):
        return "STRING"

    def visit_VARCHAR(self, type_):
        return "STRING"

    def visit_NCHAR(self, type_):
        return "STRING"

    def visit_TEXT(self, type_):
        return "STRING"

    def visit_CLOB(self, type_):
        return "STRING"

    def visit_BLOB(self, type_):
        return "BINARY"

    def visit_TIME(self, type_):
        return "TIMESTAMP"


# ``RETURNS_UNICODE`` is not present on every SQLAlchemy version; fall back
# to True when it is missing.
if hasattr(sqltypes.String, "RETURNS_UNICODE"):
    _return_unicode_str = sqltypes.String.RETURNS_UNICODE
else:
    _return_unicode_str = True


class ODPSPingError(BaseODPSError):
    """Error raised by ``ODPSDialect.do_ping`` in place of the original failure."""

    pass


class ODPSDialect(default.DefaultDialect):
    """SQLAlchemy dialect registered under the name ``odps``."""

    name = "odps"
    driver = "rest"
    preparer = ODPSIdentifierPreparer
    statement_compiler = ODPSCompiler
    supports_views = True
    supports_alter = True
    supports_pk_autoincrement = False
    supports_default_values = False
    supports_empty_insert = False
    supports_native_decimal = True
    supports_native_boolean = True
    supports_unicode_statements = True
    supports_unicode_binds = True
    returns_unicode_strings = _return_unicode_str
    description_encoding = None
    supports_multivalues_insert = True
    type_compiler = ODPSTypeCompiler
    supports_sane_rowcount = False
    supports_statement_cache = False
    _reused_odps = None
    default_schema_name = "default"

    @classmethod
    def dbapi(cls):
        """Return the package's ``dbapi`` module (imported lazily)."""
        from . import dbapi

        return dbapi

    def create_connect_args(self, url):
        """Translate a SQLAlchemy URL into DB-API connect arguments.

        The URL host is the project, and username / password are the access
        id and secret access key. The query options ``endpoint``,
        ``logview_host`` / ``logview``, ``session``, ``quota_name``,
        ``interactive_mode``, ``reuse_odps``, ``project_as_schema``,
        ``fallback_policy``, ``cache_names`` and ``cache_seconds`` are
        consumed here; whatever remains in the query is passed on as
        ``hints``.

        Side effects: with ``cache_names`` enabled a fresh ``ObjectCache`` is
        registered for this URL, and with ``reuse_odps`` enabled an ``ODPS``
        object is created and remembered for this URL (or, if one is already
        remembered, handed back through ``kwargs["odps"]``).

        :return: ``([], kwargs)`` where ``kwargs`` are the keyword arguments
            for the DB-API ``connect`` call.
        :raises ValueError: if any required connection argument is ``None``.
        """
        url_string = str(url)
        project = url.host
        if project is None and options.default_project:
            project = options.default_project
        access_id = url.username
        secret_access_key = url.password
        logview_host = options.logview_host
        endpoint = None
        session_name = None
        sqa_type = False
        quota_name = None
        reuse_odps = False
        project_as_schema = False
        fallback_policy = ""
        cache_names = False
        cache_seconds = 24 * 3600
        hints = {}
        if url.query:
            query = dict(url.query)
            if endpoint is None:
                endpoint = query.pop("endpoint", None)
            if logview_host is None:
                logview_host = query.pop("logview_host", query.pop("logview", None))
            if session_name is None:
                session_name = query.pop("session", None)
            if quota_name is None:
                quota_name = query.pop("quota_name", None)
            if sqa_type is False:
                sqa_type = query.pop("interactive_mode", "false").lower()
                # "true" selects mode "v1", "false" disables it; any other
                # value is passed through unchanged (lower-cased)
                if sqa_type == "true":
                    sqa_type = "v1"
                elif sqa_type == "false":
                    sqa_type = False
            if reuse_odps is False:
                reuse_odps = query.pop("reuse_odps", "false").lower() != "false"
            if query.get("project_as_schema", None) is not None:
                project_as_schema = (
                    query.pop("project_as_schema", "false").lower() != "false"
                )
            if fallback_policy == "":
                fallback_policy = query.pop("fallback_policy", "default")
            if cache_names is False:
                cache_names = query.pop("cache_names", "false").lower() != "false"
                cache_seconds = int(query.pop("cache_seconds", cache_seconds))
            # everything not consumed above is forwarded as hints
            hints = query

        if endpoint is None:
            endpoint = options.endpoint or DEFAULT_ENDPOINT
        if session_name is None:
            session_name = PUBLIC_SESSION_NAME

        kwargs = {
            "access_id": access_id,
            "secret_access_key": secret_access_key,
            "project": project,
            "endpoint": endpoint,
            "session_name": session_name,
            "use_sqa": sqa_type,
            "fallback_policy": fallback_policy,
            "project_as_schema": project_as_schema,
            "hints": hints,
        }
        if quota_name is not None:
            kwargs["quota_name"] = quota_name
        if access_id is None:
            # no credentials in the URL: use the account from global options
            kwargs.pop("access_id", None)
            kwargs.pop("secret_access_key", None)
            kwargs["account"] = options.account

        for k, v in six.iteritems(kwargs):
            if v is None:
                raise ValueError(
                    "{} should be provided to create connection, "
                    "you can either specify in connection string as format: "
                    '"odps://<access_id>:<access_key>@<project_name>", '
                    "or create an ODPS object and call `.to_global()` "
                    "to set it to global".format(k)
                )
        # added after the check above, so logview_host stays optional
        if logview_host is not None:
            kwargs["logview_host"] = logview_host

        if cache_names:
            _sqlalchemy_obj_list_cache[url_string] = ObjectCache(expire=cache_seconds)

        if reuse_odps:
            # the odps object can only be reused only if it will be identical
            if (
                url_string in _sqlalchemy_global_reusable_odps
                and _sqlalchemy_global_reusable_odps.get(url_string) is not None
            ):
                kwargs["odps"] = _sqlalchemy_global_reusable_odps.get(url_string)
                kwargs["access_id"] = None
                kwargs["secret_access_key"] = None
            else:
                _sqlalchemy_global_reusable_odps[url_string] = ODPS(
                    access_id=access_id,
                    secret_access_key=secret_access_key,
                    project=project,
                    endpoint=endpoint,
                    logview_host=logview_host,
                )

        return [], kwargs

    def get_odps_from_url(self, url):
        """Return an ``ODPS`` object for ``url``.

        A reused object supplied by ``create_connect_args`` is returned
        as-is; otherwise a new one is built from the connect arguments with
        the connection-only options removed and ``overwrite_global=False``.
        """
        _, kwargs = self.create_connect_args(url)
        if "odps" in kwargs:
            return kwargs["odps"]

        odps_kw = kwargs.copy()
        # these options are dropped before calling ``ODPS(...)``
        odps_kw.pop("session_name", None)
        odps_kw.pop("use_sqa", None)
        odps_kw.pop("fallback_policy", None)
        odps_kw.pop("hints", None)
        odps_kw.pop("project_as_schema", None)
        odps_kw["overwrite_global"] = False
        return ODPS(**odps_kw)

    @classmethod
    def get_list_cache(cls, url, key):
        """Return the cached value for ``key`` under ``url``.

        Returns ``None`` when no cache is registered for the URL or when the
        key is missing or expired.
        """
        url = str(url)
        if url not in _sqlalchemy_obj_list_cache:
            return None
        return _sqlalchemy_obj_list_cache[url].get(key)

    @classmethod
    def put_list_cache(cls, url, key, value):
        """Store ``value`` for ``key`` under ``url``; no-op without a registered cache."""
        url = str(url)
        if url not in _sqlalchemy_obj_list_cache:
            return
        _sqlalchemy_obj_list_cache[url][key] = value

    def get_schema_names(self, connection, **kw):
        """Return the schema names visible through ``connection``.

        When the DB-API connection has ``_project_as_schema`` set, projects
        are reported as schemas: all projects from ``list_projects`` (filtered
        by the optional ``owner`` / ``user`` / ``group`` / ``prefix`` keywords)
        if the connection has no project or ``listall`` is passed, otherwise
        just the current project. Without that flag the result of
        ``list_schemas`` is returned, falling back to ``["default"]`` when
        listing fails.
        """
        conn = self._get_dbapi_connection(connection)
        if getattr(conn, "_project_as_schema", False):
            fields = ["owner", "user", "group", "prefix"]
            if (conn.odps.project is None) or (kw.pop("listall", None) is not None):
                kwargs = {f: kw.get(f) for f in fields}
                return [proj.name for proj in conn.odps.list_projects(**kwargs)]
            else:
                return [conn.odps.project]
        else:
            try:
                return [schema.name for schema in conn.odps.list_schemas()]
            except Exception:
                # best effort: any listing failure falls back to the default
                # schema, but KeyboardInterrupt / SystemExit still propagate
                return ["default"]

    def has_table(self, connection, table_name, schema=None, **kw):
        """Return whether ``table_name`` exists, as reported by ``exist_table``."""
        conn = self._get_dbapi_connection(connection)
        schema_kw = self._get_schema_kw(connection, schema=schema)
        return conn.odps.exist_table(table_name, **schema_kw)

    @classmethod
    def _get_dbapi_connection(cls, sa_connection):
        """Return the DB-API connection underneath a SQLAlchemy connection or engine."""
        if isinstance(sa_connection, Engine):
            # NOTE(review): the connection opened here is not closed by this
            # method.
            sa_connection = sa_connection.connect()
        return sa_connection.connection.connection

    @classmethod
    def _get_schema_kw(cls, connection, schema=None):
        """Return the keyword argument under which ``schema`` is passed on.

        ``{"project": schema}`` when the DB-API connection has
        ``_project_as_schema`` set, ``{"schema": schema}`` otherwise.
        """
        db_conn = cls._get_dbapi_connection(connection)
        if getattr(db_conn, "_project_as_schema", False):
            return dict(project=schema)
        else:
            return dict(schema=schema)

    def get_columns(self, connection, table_name, schema=None, **kw):
        """Return column descriptions for ``table_name``.

        Each entry is a dict with ``name``, ``type``, ``nullable`` (always
        True), ``default`` (always None) and ``comment``.

        :raises NoSuchTableError: if reading the table schema raises
            ``NoSuchObject``.
        :raises KeyError: if a column type has no entry in
            ``_odps_type_to_sqlalchemy_type``.
        """
        conn = self._get_dbapi_connection(connection)
        schema_kw = self._get_schema_kw(connection, schema=schema)
        table = conn.odps.get_table(table_name, **schema_kw)
        result = []
        try:
            for col in table.table_schema.columns:
                col_type = _odps_type_to_sqlalchemy_type[type(col.type)]
                result.append(
                    {
                        "name": col.name,
                        "type": col_type,
                        "nullable": True,
                        "default": None,
                        "comment": col.comment,
                    }
                )
        except NoSuchObject as e:
            # convert ODPSError to SQLAlchemy NoSuchTableError
            raise NoSuchTableError(str(e))
        return result

    def get_foreign_keys(self, connection, table_name, schema=None, **kw):
        # ODPS has no support for foreign keys.
        return []

    def get_pk_constraint(self, connection, table_name, schema=None, **kw):
        # ODPS has no support for primary keys.
        return []

    def get_indexes(self, connection, table_name, schema=None, **kw):
        # ODPS has no support for indexes
        return []

    def _iter_tables(self, connection, schema=None, types=None, **kw):
        """Return the names of tables in ``schema``, optionally by table type.

        :param types: iterable of table types to list; when empty or ``None``
            a single unfiltered ``list_tables`` call is made.
        :return: list of table names. Names rejected by the test-only
            ``test_setting.get_tables_filter`` callable are left out.

        Results are served from and stored into the per-URL list cache when
        one is registered for the engine URL.
        """
        # ``types`` may be None (its default), which ``tuple()`` rejects
        cache_key = ("tables", schema, tuple(types or ()))
        cached = self.get_list_cache(connection.engine.url, cache_key)
        if cached is not None:
            return cached

        conn = self._get_dbapi_connection(connection)
        filter_ = getattr(test_setting, "get_tables_filter", None)
        schema_kw = self._get_schema_kw(connection, schema=schema)
        if not types:
            it = conn.odps.list_tables(**schema_kw)
        else:
            its = []
            for table_type in types:
                list_kw = schema_kw.copy()
                list_kw["type"] = table_type
                its.append(conn.odps.list_tables(**list_kw))
            it = itertools.chain(*its)

        result = [t.name for t in it if filter_ is None or filter_(t.name)]
        self.put_list_cache(connection.engine.url, cache_key, result)
        return result

    def get_table_names(self, connection, schema=None, **kw):
        """Return the names of managed and external tables in ``schema``."""
        return self._iter_tables(
            connection,
            schema=schema,
            types=[Table.Type.MANAGED_TABLE, Table.Type.EXTERNAL_TABLE],
            **kw
        )

    def get_view_names(self, connection, schema=None, **kw):
        """Return the names of virtual and materialized views in ``schema``."""
        return self._iter_tables(
            connection,
            schema=schema,
            types=[Table.Type.VIRTUAL_VIEW, Table.Type.MATERIALIZED_VIEW],
            **kw
        )

    def get_table_comment(self, connection, table_name, schema=None, **kw):
        """Return ``{"text": comment}`` for ``table_name``."""
        conn = self._get_dbapi_connection(connection)
        schema_kw = self._get_schema_kw(connection, schema=schema)
        comment = conn.odps.get_table(table_name, **schema_kw).comment
        return {"text": comment}

    @classmethod
    def _is_stack_superset(cls, tb):
        """Return True if any frame on the traceback's call stack comes from a
        file whose path contains ``superset``.

        Walks from ``tb.tb_frame`` outwards through ``f_back``. Any error
        while inspecting the frames yields False.
        """
        try:
            cur_frame = tb.tb_frame
            while cur_frame is not None:
                if "superset" in cur_frame.f_code.co_filename:
                    return True
                cur_frame = cur_frame.f_back
            return False
        except Exception:  # pragma: no cover
            return False

    def do_ping(self, dbapi_connection):
        """Stop raising RuntimeError when ping by Superset

        Delegates to the default ping. ``InternalServerError`` always
        propagates unchanged. Any other failure propagates unchanged unless
        the call stack contains a Superset frame, in which case it is
        re-raised as ``ODPSPingError`` with the original traceback.

        :raises ODPSPingError: when the ping fails underneath Superset.
        """
        try:
            return super(ODPSDialect, self).do_ping(dbapi_connection)
        except InternalServerError:
            raise
        except BaseException as ex:
            _, _, tb = sys.exc_info()
            if not self._is_stack_superset(tb):
                raise
            # the failure may be an exception without args, or a non-ODPS
            # error lacking the attributes copied below; neither case may
            # mask the intended ODPSPingError
            message = ex.args[0] if ex.args else str(ex)
            new_err = ODPSPingError(message)
            for attr in (
                "request_id",
                "instance_id",
                "code",
                "host_id",
                "endpoint",
                "tag",
            ):
                setattr(new_err, attr, getattr(ex, attr, None))
            six.reraise(ODPSPingError, new_err, tb)

    def do_rollback(self, dbapi_connection):
        # No transactions for ODPS
        pass

    def _check_unicode_returns(self, connection, additional_tests=None):
        # We decode everything as UTF-8
        return True

    def _check_unicode_description(self, connection):
        # We decode everything as UTF-8
        return True