odps/dbapi.py (368 lines of code) (raw):
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import warnings
from .compat import enum
from .config import options
from .core import ODPS
from .errors import InstanceTypeNotSupported, NotSupportedError, ODPSError
from .models.session.v1 import PUBLIC_SESSION_NAME
from .utils import to_odps_scalar
# PEP 249 module globals
apilevel = "2.0"
threadsafety = 2 # Threads may share the module and connections.
paramstyle = "named" # Named placeholders, e.g. ...WHERE name=:name (Cursor also accepts ``?`` placeholders when parameters are a list or tuple)
class Error(Exception):
    """Exception type exposed by this DB-API module."""
class State(enum.Enum):
    """Execution state recorded on a cursor (cursors are reset to ``NONE``)."""

    NONE = 0
    RUNNING = 1
    FINISHED = 2
def connect(*args, **kwargs):
    """Create a connection to the database.

    All positional and keyword arguments are handed to :py:class:`Connection`
    unchanged; see that class for their meaning.

    :returns: a :py:class:`Connection` object.
    """
    connection = Connection(*args, **kwargs)
    return connection
# Maps a fallback policy name to the ODPS error codes that trigger it.
# A cursor running in interactive (SQA) mode falls back to a regular
# ``execute_sql`` call when the error text contains one of the codes of an
# enabled policy.
FALLBACK_POLICIES = {
    "unsupported": ["ODPS-185"],
    "upgrading": ["ODPS-182", "ODPS-184"],
    "noresource": ["ODPS-183"],
    "timeout": ["ODPS-186"],
    "generic": ["ODPS-180"],
}
# Shorthand names accepted in a ``fallback_policy`` string; each expands to
# the listed policy names from FALLBACK_POLICIES.
FALLBACK_POLICY_ALIASES = {
    "default": ["unsupported", "upgrading", "noresource", "timeout"],
    "all": ["unsupported", "upgrading", "noresource", "timeout", "generic"],
}
class Connection(object):
    """DB-API 2.0 style connection backed by an ``ODPS`` entry object.

    :param access_id: access id forwarded to ``ODPS``. An ``ODPS`` instance
        may also be passed here, in which case it is used as ``odps``.
    :param secret_access_key: forwarded to ``ODPS``.
    :param project: forwarded to ``ODPS``.
    :param endpoint: forwarded to ``ODPS``.
    :param account: forwarded to ``ODPS``.
    :param session_name: interactive session name; defaults to
        ``PUBLIC_SESSION_NAME`` when not given.
    :param odps: an existing ``ODPS`` object to use instead of creating one.
    :param hints: default hints handed to every cursor of this connection.
    :param quota_name: quota name handed to ``ODPS`` and to every cursor.
    :param use_sqa: (keyword) ``"v2"`` selects SQA type ``"v2"``, any other
        truthy value selects ``"v1"``, a falsy value disables it.
    :param fallback_policy: (keyword) comma separated fallback policy names
        handed to every cursor; defaults to ``""``.
    :param project_as_schema: (keyword) defaults to
        ``options.sqlalchemy.project_as_schema``.

    Remaining keyword arguments are forwarded to ``ODPS`` when a new entry
    object is created.

    :raises ValueError: when both ``access_id`` and ``odps`` are specified.
    """

    def __init__(
        self,
        access_id=None,
        secret_access_key=None,
        project=None,
        endpoint=None,
        account=None,
        session_name=None,
        odps=None,
        hints=None,
        quota_name=None,
        **kw
    ):
        # allow ``Connection(odps_object)`` as a shorthand
        if isinstance(access_id, ODPS):
            access_id, odps = None, access_id

        use_sqa = kw.pop("use_sqa", False)
        if use_sqa == "v2":
            self._sqa_type = "v2"
        elif use_sqa:
            self._sqa_type = "v1"
        else:
            self._sqa_type = None

        self._quota_name = quota_name
        self._fallback_policy = kw.pop("fallback_policy", "")
        self._project_as_schema = kw.pop(
            "project_as_schema", options.sqlalchemy.project_as_schema
        )

        if odps is None:
            self._odps = ODPS(
                access_id=access_id,
                secret_access_key=secret_access_key,
                project=project,
                endpoint=endpoint,
                account=account,
                quota_name=quota_name,
                **kw
            )
        else:
            if access_id is not None:
                raise ValueError("Either access_id or odps can be specified")
            self._odps = odps

        if self._project_as_schema is None:
            try:
                self._project_as_schema = (
                    not self._odps.is_schema_namespace_enabled()
                )
            except Exception:  # pragma: no cover
                # Detection is best effort: leave the value as None on
                # failure, but do not swallow KeyboardInterrupt / SystemExit
                # the way a bare ``except`` would.
                pass

        if session_name is not None:
            self._session_name = session_name
        else:
            self._session_name = PUBLIC_SESSION_NAME
        self._hints = hints

    @property
    def odps(self):
        """The ``ODPS`` entry object used by this connection."""
        return self._odps

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def cursor(self, *args, **kwargs):
        """Return a new :py:class:`Cursor` object using the connection.

        Extra arguments are forwarded to the cursor together with the SQA
        type, fallback policy, hints and quota name of this connection.
        """
        return Cursor(
            self,
            *args,
            sqa_type=self._sqa_type,
            fallback_policy=self._fallback_policy,
            hints=self._hints,
            quota_name=self._quota_name,
            **kwargs
        )

    def close(self):
        """Do nothing: there is no long polling for ODPS to shut down."""
        pass

    def commit(self):
        """Do nothing: ODPS does not support transactions."""
        pass

    def rollback(self):
        """Always raise, as ODPS does not have transactions.

        :raises NotSupportedError: unconditionally.
        """
        raise NotSupportedError("ODPS does not have transactions")
default_arraysize = 1000


class Cursor(object):
    """DB-API 2.0 style cursor executing SQL through a connection's ODPS object.

    :param connection: the owning connection; its ``odps`` attribute is used
        to submit SQL.
    :param arraysize: default row count for :py:meth:`fetchmany`.
    :param sqa_type: ``"v1"`` or ``"v2"`` to run SQL interactively with
        fallback, ``None`` to use ``execute_sql``.
    :param fallback_policy: comma separated policy names or aliases (see
        ``FALLBACK_POLICIES`` / ``FALLBACK_POLICY_ALIASES``). ``None`` and
        empty entries are ignored.
    :param hints: default hints merged into every :py:meth:`execute` call.
    :param quota_name: quota name passed along when SQL is submitted.
    """

    def __init__(
        self,
        connection,
        arraysize=default_arraysize,
        sqa_type=None,
        fallback_policy="",
        hints=None,
        quota_name=None,
        **kwargs
    ):
        self._connection = connection
        self._arraysize = arraysize
        self._reset_state()
        self.lastrowid = None
        self._sqa_type = sqa_type
        self._fallback_policy = []
        self._hints = hints
        self._quota_name = quota_name

        # Tolerate ``fallback_policy=None`` and skip blank entries so that an
        # empty string does not register an empty policy name.
        for policy in (fallback_policy or "").split(","):
            policy = policy.strip()
            if not policy:
                continue
            if policy in FALLBACK_POLICY_ALIASES:
                self._fallback_policy.extend(FALLBACK_POLICY_ALIASES[policy])
            else:
                self._fallback_policy.append(policy)

    def _reset_state(self):
        """Forget the results of any previously executed statement."""
        self._state = State.NONE
        self._description = None

        # odps instance and download session
        self._instance = None
        self._download_session = None

    @property
    def arraysize(self):
        """Default number of rows returned by :py:meth:`fetchmany`."""
        return self._arraysize

    @arraysize.setter
    def arraysize(self, value):
        # NOTE(review): values below ``default_arraysize`` are raised to it by
        # ``max``; kept as is, confirm this lower bound is intended.
        try:
            self._arraysize = max(int(value), default_arraysize)
        except (TypeError, ValueError):
            # int() raises ValueError for non-numeric strings and TypeError
            # for unsupported types; both mean "not an integer".
            warnings.warn(
                "arraysize has to be a integer, got {}, "
                "will set default value 1000".format(value)
            )
            self._arraysize = default_arraysize

    @property
    def instance(self):
        """Instance object returned by the last submitted statement, if any."""
        return self._instance

    @property
    def connection(self):
        """The connection this cursor was created from."""
        return self._connection

    @property
    def description(self):
        """This read-only attribute is a sequence of 7-item sequences.

        Each of these sequences contains information describing one result column:

        - name
        - type_code
        - display_size (None in current implementation)
        - internal_size (None in current implementation)
        - precision (None in current implementation)
        - scale (None in current implementation)
        - null_ok (always True in current implementation)

        This attribute will be ``None`` if the cursor has not had an operation
        invoked via the :py:meth:`execute` method yet. When no download
        session can be opened for the statement, a single string column named
        ``_c0`` is reported.
        """
        if self._instance is None:
            return None
        if self._description is None:
            self._check_download_session()
            if self._download_session is not None:
                self._description = [
                    (col.name, col.type.name, None, None, None, None, True)
                    for col in self._download_session.schema.columns
                ]
            else:
                self._description = [
                    ("_c0", "string", None, None, None, None, True)
                ]
        return self._description

    @staticmethod
    def _find_placeholders(sql, placeholder_type):
        """Locate parameter placeholders that are outside of string literals.

        :param sql: SQL text to scan.
        :param placeholder_type: ``"named"`` to look for ``:name`` or
            ``"positional"`` to look for ``?``.
        :returns: list of ``(position, key)`` tuples in order of appearance;
            ``key`` is the name without the colon, or ``None`` for ``?``.
        """
        placeholders = []
        quote_char = None  # quote that opened the current literal, if any
        length = len(sql)
        i = 0
        while i < length:
            c = sql[i]
            if quote_char is not None:
                if c == quote_char:
                    quote_char = None
                elif c == "\\":
                    # skip the escaped character inside the literal
                    i += 1
            elif c in ('"', "'"):
                quote_char = c
            elif c == ":" and placeholder_type == "named":
                # check start char (should be alphabetical or underline)
                if i + 1 < length and (sql[i + 1].isalpha() or sql[i + 1] == "_"):
                    j = i + 1
                    while j < length and (sql[j].isalnum() or sql[j] == "_"):
                        j += 1
                    placeholders.append((i, sql[i + 1 : j]))
                    i = j
                    continue
            elif c == "?" and placeholder_type == "positional":
                placeholders.append((i, None))
            i += 1
        return placeholders

    @classmethod
    def _replace_sql_parameters(cls, sql, parameters):
        """Substitute placeholders in ``sql`` with literals built from ``parameters``.

        :param sql: SQL text containing ``:name`` or ``?`` placeholders.
        :param parameters: a dict for ``:name`` placeholders, or a list/tuple
            for ``?`` placeholders. Values are rendered with
            ``to_odps_scalar``.
        :returns: the SQL text with every placeholder replaced.
        :raises KeyError: a named placeholder has no entry in ``parameters``.
        :raises ValueError: the number of ``?`` placeholders differs from the
            number of parameters.
        :raises TypeError: ``parameters`` is neither dict, list nor tuple.
        """
        if isinstance(parameters, dict):
            placeholders = cls._find_placeholders(sql, "named")
            # check if all parameters provided before rendering any value
            for _, key in placeholders:
                if key not in parameters:
                    raise KeyError("Missing parameter '%s'" % key)

            pieces = []
            prev = 0
            for pos, key in placeholders:
                pieces.append(sql[prev:pos])
                pieces.append(to_odps_scalar(parameters[key]))
                prev = pos + 1 + len(key)  # skip the colon and the name
            pieces.append(sql[prev:])
            return "".join(pieces)
        elif isinstance(parameters, (list, tuple)):
            placeholders = cls._find_placeholders(sql, "positional")
            num_placeholders = len(placeholders)
            num_params = len(parameters)
            if num_placeholders != num_params:
                raise ValueError(
                    "Expected %d parameters, got %d" % (num_placeholders, num_params)
                )

            pieces = []
            prev = 0
            for (pos, _), value in zip(placeholders, parameters):
                pieces.append(sql[prev:pos])
                pieces.append(to_odps_scalar(value))
                prev = pos + 1  # skip the question mark
            pieces.append(sql[prev:])
            return "".join(pieces)
        else:
            raise TypeError("Parameters must be a dictionary or tuple")

    def execute(self, operation, parameters=None, **kwargs):
        """Prepare and execute a database operation (query or command).

        Parameters may be provided as sequence or mapping and will be bound to
        variables in the operation (``?`` for sequences, ``:name`` for
        mappings).

        :param operation: SQL text to run.
        :param parameters: optional dict, list or tuple of values to bind.
        :param kwargs: ``async`` / ``async_`` submits the statement with
            ``run_sql`` of the ODPS object instead of the configured runner;
            ``hints`` are merged over the cursor's default hints.
        :returns: the cursor itself.
        """
        if "async" in kwargs:
            async_ = kwargs["async"]
        else:
            async_ = kwargs.get("async_", False)

        # prepare statement
        if parameters is None:
            sql = operation
        else:
            sql = self._replace_sql_parameters(operation, parameters)

        self._reset_state()
        odps = self._connection.odps

        if async_:
            # asynchronous submission takes precedence over the SQA settings
            run_sql = odps.run_sql
        elif self._sqa_type == "v2":
            run_sql = functools.partial(self._run_sqa_with_fallback, use_mcqa_v2=True)
        elif self._sqa_type == "v1":
            run_sql = self._run_sqa_with_fallback
        else:
            run_sql = odps.execute_sql

        hints = dict(self._hints or {})
        hints.update(kwargs.get("hints") or {})
        self._instance = run_sql(sql, hints=hints, quota_name=self._quota_name)
        return self

    def executemany(self, operation, seq_of_parameters):
        """Execute ``operation`` once for every entry of ``seq_of_parameters``.

        :returns: the cursor itself.
        """
        for parameter in seq_of_parameters:
            self.execute(operation, parameter)
        return self

    def executescript(self, sql_script, **kwargs):
        """Execute ``sql_script`` with hint ``odps.sql.submit.mode`` set to ``script``.

        Other keyword arguments (for instance ``async_``) are forwarded to
        :py:meth:`execute` rather than being dropped.

        :returns: the cursor itself.
        """
        hints = dict(kwargs.get("hints") or {})
        hints["odps.sql.submit.mode"] = "script"
        kwargs["hints"] = hints
        return self.execute(sql_script, **kwargs)

    def _sqa_error_should_fallback(self, err_str):
        """Tell whether ``err_str`` holds an error code of an enabled fallback policy."""
        if "ODPS-18" not in err_str:
            return False
        for fallback_case in self._fallback_policy:
            # unknown policy names simply match nothing
            for error_code in FALLBACK_POLICIES.get(fallback_case) or ():
                if error_code in err_str:
                    return True
        return False

    def _run_sqa_with_fallback(self, sql, **kw):
        """Run ``sql`` interactively, falling back to ``execute_sql`` when allowed.

        On success the opened reader is stored as the download session and the
        interactive instance is returned. When an ``ODPSError`` matches one of
        the enabled fallback policies, the result of ``odps.execute_sql(sql)``
        is returned instead. An error mentioning ``OdpsTaskTimeout`` makes the
        loop wait for the instance and try to open the reader again. Any other
        ``ODPSError`` is re-raised.

        :param sql: SQL text to run.
        :param kw: only ``use_mcqa_v2`` is read here.
        """
        # NOTE(review): ``hints`` and ``quota_name`` passed by ``execute`` are
        # not forwarded; the quota is read from the connection's ODPS object.
        # Confirm whether this is intended.
        odps = self._connection.odps
        session_name = self._connection._session_name
        quota_name = self._connection._odps.quota_name
        use_v2 = kw.get("use_mcqa_v2", False)
        inst = None
        while True:
            try:
                if inst is None:
                    if use_v2:
                        inst = odps.run_sql_interactive(
                            sql, quota_name=quota_name, use_mcqa_v2=use_v2
                        )
                    else:
                        inst = odps.run_sql_interactive(sql, service_name=session_name)
                else:
                    inst.wait_for_success(interval=0.5)
                rd = inst.open_reader(tunnel=True, limit=False)
                if not rd:
                    raise ODPSError("failed to create direct download")
                rd.schema  # will check if task is ok
                self._download_session = rd
                return inst
            except ODPSError as e:
                err_str = str(e)
                if self._sqa_error_should_fallback(err_str):
                    return odps.execute_sql(sql)
                elif "OdpsTaskTimeout" in err_str:
                    # tunnel failed to wait data cache result. fallback to normal wait.
                    pass
                else:
                    # bare raise keeps the original traceback
                    raise

    def cancel(self):
        """Stop the instance of the last statement, if there is one."""
        if self._instance is not None:
            self._instance.stop()

    def close(self):
        """Drop all state kept for the last statement."""
        self._reset_state()

    def _check_download_session(self):
        """Open a tunnel reader for the current instance if none is open yet."""
        if not self._download_session and self._instance:
            try:
                self._download_session = self._instance.open_reader(
                    tunnel=True, limit=False
                )
            except InstanceTypeNotSupported:
                # not select, cannot create session
                self._download_session = None

    def _fetch_non_select(self):
        """Return the raw result of a non-select statement as a single row."""
        with self._instance.open_reader() as reader:
            return [(reader.raw,)]

    def _fetch(self, size):
        """Read up to ``size`` rows (all remaining rows when ``size`` is -1)."""
        self._check_download_session()
        if self._download_session is None:
            return self._fetch_non_select()

        rows = []
        fetched = 0
        while size == -1 or fetched < size:
            try:
                rows.append(next(self._download_session).values)
            except StopIteration:
                break
            fetched += 1
        return rows

    def __iter__(self):
        while True:
            res = self.fetchone()
            if res is not None:
                yield res
            # without a download session fetchone() keeps returning the same
            # raw row, hence stop after the first one
            if res is None or self._download_session is None:
                break

    def next(self):
        """Return the next row with the semantics of :py:meth:`fetchone`.

        :raises StopIteration: when :py:meth:`fetchone` returns ``None``.
        """
        res = self.fetchone()
        if res is None:
            raise StopIteration
        return res

    __next__ = next

    def fetchone(self):
        """Return the next row, or ``None`` when no row is available."""
        results = self._fetch(1)
        if len(results) == 1:
            return results[0]
        return None

    def fetchmany(self, size=None):
        """Return up to ``size`` rows; ``size`` defaults to :py:attr:`arraysize`."""
        if size is None:
            size = self._arraysize
        return self._fetch(size)

    def fetchall(self):
        """Return all remaining rows."""
        return self._fetch(-1)