odps/models/schemas.py (167 lines of code) (raw):
# Copyright 1999-2024 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from .. import serializers
from ..compat import six
from ..errors import (
InternalServerError,
InvalidParameter,
MethodNotAllowed,
NoSuchObject,
)
from ..utils import with_wait_argument
from .core import Iterable
from .schema import Schema
# module-level logger
logger = logging.getLogger(__name__)

# Cache of schema-API availability, keyed by (endpoint, project name).
# True: an API call succeeded; False: the API was rejected and the SQL
# fallback is used. A missing key means the project has not been probed yet.
_project_has_schema_api = {}
def with_schema_api_fallback(fallback_fun, is_iter=False):
    """Build a decorator that falls back to a legacy implementation when the
    schema API is rejected for a project.

    The decorated method is tried first. If it (or ``self._check_schema_api``)
    raises :class:`MethodNotAllowed` or :class:`InvalidParameter`, the project
    is recorded in ``_project_has_schema_api`` as lacking the schema API and
    ``fallback_fun`` is called with the same arguments instead. A successful
    call records the project as supporting the API.

    :param fallback_fun: legacy implementation, called as
        ``fallback_fun(self, *args, **kwargs)``
    :param is_iter: True when the decorated method is a generator; the
        fallback must then return an iterable as well
    :return: decorator for methods of an object exposing ``parent`` and
        ``_check_schema_api``
    """

    def _api_key(obj):
        # cache key shared by every decorated method of the same project
        return obj.parent.odps.endpoint, obj.parent.name

    def decorator(fun):
        @six.wraps(fun)
        def wrapper(self, *args, **kwargs):
            key = _api_key(self)
            try:
                self._check_schema_api()
                # ``**kwargs`` hands ``fun`` a fresh dict, so keys it pops are
                # still available to the fallback below
                result = fun(self, *args, **kwargs)
                _project_has_schema_api[key] = True
                return result
            except (MethodNotAllowed, InvalidParameter):
                _project_has_schema_api[key] = False
                return fallback_fun(self, *args, **kwargs)

        @six.wraps(fun)
        def iter_wrapper(self, *args, **kwargs):
            key = _api_key(self)
            yielded = False
            try:
                self._check_schema_api()
                for item in fun(self, *args, **kwargs):
                    yielded = True
                    yield item
                _project_has_schema_api[key] = True
                return
            except (MethodNotAllowed, InvalidParameter):
                if yielded or _project_has_schema_api.get(key):
                    # The API already served (part of) this listing: falling
                    # back now would yield duplicated items, so re-raise.
                    raise
                _project_has_schema_api[key] = False
            for item in fallback_fun(self, *args, **kwargs):
                yield item

        return iter_wrapper if is_iter else wrapper

    return decorator
class Schemas(Iterable):
    """Collection of the schemas belonging to a project.

    Listing, creation, deletion and membership tests go through the schema
    API first; when the API is rejected for the project they fall back to SQL
    statements (``SHOW SCHEMAS`` / ``CREATE SCHEMA`` / ``DROP SCHEMA``) via
    ``with_schema_api_fallback``.
    """

    # paging marker returned by the server; None/empty when no page follows
    marker = serializers.XMLNodeField("Marker")
    max_items = serializers.XMLNodeField("MaxItems")
    schemas = serializers.XMLNodesReferencesField(Schema, "Schema")

    def __iter__(self):
        return self.iterate()

    def resource(self, client=None, endpoint=None):
        # schema URLs are built on top of the parent project's resource
        return self.parent.resource(client, endpoint=endpoint)

    def _check_schema_api(self):
        """Raise MethodNotAllowed when the schema API is already known to be
        unavailable for this project, so callers go straight to the fallback.
        """
        key = (self.parent.odps.endpoint, self.parent.name)
        if not _project_has_schema_api.get(key, True):
            raise MethodNotAllowed("Schema API not supported")

    def _iterate_legacy(self, name=None, owner=None):
        """List schemas by running ``SHOW SCHEMAS``.

        :raises ValueError: if ``name`` or ``owner`` filtering is requested,
            which this path cannot do
        """
        if name is not None or owner is not None:
            raise ValueError(
                "Iterating schemas with name or owner not supported on current service"
            )
        inst = self.parent.odps.execute_sql("SHOW SCHEMAS IN %s" % self.parent.name)
        output = inst.get_task_results().get("AnonymousSQLTask")
        for line in output.splitlines():
            schema_name = line.strip()
            # skip blank lines: an empty result must not yield a schema
            # with an empty name
            if not schema_name:
                continue
            yield Schema(name=schema_name, parent=self, client=self._client)

    @with_schema_api_fallback(fallback_fun=_iterate_legacy, is_iter=True)
    def iterate(self, name=None, owner=None):
        """Iterate over the schemas of the project through the schema API.

        Pages are fetched one after another, following the marker returned
        by the server.

        :param name: sent to the server as the ``name`` query parameter
        :param owner: sent to the server as the ``owner`` query parameter
        :return: generator of Schema objects
        """
        params = {"expectmarker": "true"}
        if name is not None:
            params["name"] = name
        if owner is not None:
            params["owner"] = owner
        schema_name = self._get_schema_name()
        if schema_name is not None:
            params["curr_schema"] = schema_name

        def _it():
            # Fetch one page. Returns None when the previous page carried no
            # marker, i.e. there is nothing left to fetch.
            last_marker = params.get("marker")
            if "marker" in params and (last_marker is None or len(last_marker) == 0):
                return
            url = self.resource() + "/schemas"
            resp = self._client.get(url, params=params)
            r = Schemas.parse(self._client, resp, obj=self)
            params["marker"] = r.marker
            return r.schemas

        while True:
            schemas = _it()
            if schemas is None:
                break
            for schema in schemas:
                yield schema

    @with_wait_argument
    def _create_legacy(self, obj=None, async_=False, **kwargs):
        """Create a schema with a ``CREATE SCHEMA`` statement.

        The name comes from ``obj`` when it is a Schema; otherwise from the
        ``schema_name`` keyword, then ``obj`` (a name), then the ``name``
        keyword (the form :meth:`create` accepts).

        :return: the new Schema when waiting, else the running SQL instance
        :raises ValueError: if no schema name can be determined
        """
        schema_name = kwargs.pop("schema_name", None)
        if isinstance(obj, Schema):
            schema_name = obj.name
        elif schema_name is None:
            # mirror create(): a positional name wins over ``name=``
            schema_name = obj if obj is not None else kwargs.pop("name", None)
        if schema_name is None:
            raise ValueError("Schema name not specified")
        inst = self.parent.odps.run_sql(
            "CREATE SCHEMA %s.%s" % (self.parent.name, schema_name)
        )
        if not async_:
            inst.wait_for_success()
            return Schema(name=schema_name, parent=self, client=self._client)
        return inst

    @with_schema_api_fallback(fallback_fun=_create_legacy)
    def create(self, obj=None, **kwargs):
        """Create a schema through the schema API.

        :param obj: a schema name or a Schema object; when omitted a Schema
            is built from ``kwargs``
        :return: the created Schema
        """
        # waiting options only matter for the SQL fallback
        kwargs.pop("async_", None)
        kwargs.pop("wait", None)
        if isinstance(obj, six.string_types):
            kwargs["name"] = obj
            obj = None
        if obj is not None:
            schema = obj
        else:
            schema = Schema(parent=self, client=self._client, **kwargs)
        if schema.parent is None:
            schema._parent = self
        if schema._client is None:
            schema._client = self._client
        headers = {"Content-Type": "application/xml"}
        data = schema.serialize()
        resource = self.resource() + "/schemas"
        self._client.post(resource, data, headers=headers)
        return schema

    @with_wait_argument
    def _delete_legacy(self, schema_name, async_=False):
        """Drop a schema with a ``DROP SCHEMA`` statement.

        :return: result of ``wait_for_success()`` when waiting, else the
            running SQL instance
        """
        if isinstance(schema_name, Schema):
            schema_name = schema_name.name
        inst = self.parent.odps.run_sql(
            "DROP SCHEMA %s.%s" % (self.parent.name, schema_name)
        )
        if not async_:
            return inst.wait_for_success()
        return inst

    @with_schema_api_fallback(fallback_fun=_delete_legacy)
    def delete(self, schema_name, async_=False):
        """Delete a schema (given by name or Schema object) through the
        schema API. ``async_`` is unused here and only honoured by the SQL
        fallback.
        """
        if isinstance(schema_name, Schema):
            schema_name = schema_name.name
        resource = self.resource() + "/schemas/" + schema_name
        self._client.delete(resource)

    def _get(self, item):
        # wrap a plain name into a Schema bound to this collection
        if isinstance(item, Schema):
            return item
        return Schema(name=item, parent=self, client=self._client)

    def _contains_legacy(self, item):
        """Check existence by requesting the first item of the schema's
        ``functions``; errors that signal a missing or invalid schema mean
        False, any other error propagates.
        """
        try:
            next(self._get(item).functions)
        except StopIteration:
            # the schema exists but has no functions
            pass
        except NoSuchObject:
            return False
        except InvalidParameter as ex:
            if "NoSuchObjectException" in str(ex):
                return False
            raise
        except InternalServerError as ex:
            if "invalid schema name" in str(ex).lower():
                return False
            raise
        return True

    @with_schema_api_fallback(fallback_fun=_contains_legacy)
    def __contains__(self, item):
        """Check existence by reloading the schema through the schema API."""
        schema = self._get(item)
        try:
            schema.reload()
            return True
        except NoSuchObject:
            return False