core/maxframe/session.py (814 lines of code) (raw):
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import concurrent.futures
import logging
import random
import string
import threading
import warnings
from abc import ABC, ABCMeta, abstractmethod
from concurrent.futures import Future as SyncFuture
from dataclasses import dataclass
from functools import wraps
from typing import Callable, Coroutine, Dict, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse
from weakref import ref
from odps import ODPS
from maxframe.core import TileableType
from maxframe.lib.aio import Isolation, get_isolation, new_isolation, stop_isolation
from maxframe.typing_ import ClientType
from maxframe.utils import classproperty, implements, relay_future
from .config import options
# Module-level logger; used further down to report the progress of
# cancelling a running task.
logger = logging.getLogger(__name__)
@dataclass
class Progress:
    """Mutable holder of the value returned by ``ExecutionInfo.progress()``."""

    # ``_execute_with_progress`` multiplies this by 100 before updating the
    # progress bar, so it is presumably a fraction in [0, 1] -- confirm with
    # the code that fills it in.
    value: float = 0.0
@dataclass
class Profiling:
    """Mutable holder of the value returned by ``ExecutionInfo.profiling_result()``."""

    # ``None`` until something assigns a result; nothing in this file does.
    result: Optional[dict] = None
class ExecutionInfo:
    """
    Handle of an execution that runs as an asyncio task on another loop.

    Every thread using the handle gets its own futures tracking the task
    (stored in a ``threading.local``): a ``concurrent.futures.Future`` used
    for blocking waits and an asyncio future used for awaiting. Attributes
    that are not defined on this class are looked up on the asyncio future,
    which is how e.g. ``add_done_callback`` is made available.

    Parameters
    ----------
    aio_task : asyncio.Task
        Task performing the execution.
    progress : Progress
        Holder whose ``value`` is returned by :meth:`progress`.
    profiling : Profiling
        Holder whose ``result`` is returned by :meth:`profiling_result`.
    loop : asyncio.AbstractEventLoop
        Loop on which ``aio_task`` is awaited.
    to_execute_tileables : list
        Tileables being executed. Only weak references are kept.
    """

    def __init__(
        self,
        aio_task: asyncio.Task,
        progress: Progress,
        profiling: Profiling,
        loop: asyncio.AbstractEventLoop,
        to_execute_tileables: List[TileableType],
    ):
        self._aio_task = aio_task
        self._progress = progress
        self._profiling = profiling
        self._loop = loop
        # weak references, so the handle itself does not keep tileables alive
        self._to_execute_tileables = [ref(t) for t in to_execute_tileables]
        # per-thread storage of the lazily created futures
        self._future_local = threading.local()

    def _ensure_future(self):
        """Create the calling thread's blocking future for the task if missing."""
        local = self._future_local
        if hasattr(local, "future"):
            return

        async def wait():
            return await self._aio_task

        local.future = asyncio.run_coroutine_threadsafe(wait(), self._loop)

    def _ensure_aio_future(self):
        """
        Return the calling thread's asyncio future for the task.

        It is created on first use and kept apart from the blocking future
        because ``asyncio.wrap_future`` needs an event loop in the calling
        thread, which a plain blocking :meth:`result` call does not have to
        provide.
        """
        self._ensure_future()
        local = self._future_local
        if not hasattr(local, "aio_future"):
            local.aio_future = asyncio.wrap_future(local.future)
        return local.aio_future

    @property
    def loop(self):
        """Event loop the task is awaited on."""
        return self._loop

    @property
    def aio_task(self):
        """The wrapped asyncio task."""
        return self._aio_task

    def progress(self) -> float:
        """Return the current value of the progress holder."""
        return self._progress.value

    @property
    def to_execute_tileables(self) -> List[TileableType]:
        """
        Dereferenced tileables of this execution.

        An entry is ``None`` when the tileable has already been collected.
        """
        return [t() for t in self._to_execute_tileables]

    def profiling_result(self) -> dict:
        """Return the current value of the profiling holder."""
        return self._profiling.result

    def result(self, timeout=None):
        """
        Block until the task finishes and return its result.

        Parameters
        ----------
        timeout
            Passed to ``concurrent.futures.Future.result``.
        """
        self._ensure_future()
        return self._future_local.future.result(timeout=timeout)

    def cancel(self):
        """Request cancellation of the wrapped task."""
        self._aio_task.cancel()

    def __getattr__(self, attr):
        # Only reached when normal lookup fails. If ``__init__`` has not run
        # (``copy`` / ``pickle`` probe a bare instance for ``__setstate__``),
        # ``self._future_local`` is missing too and reading it would re-enter
        # this method until RecursionError, so fail with a regular
        # AttributeError instead.
        if "_future_local" not in self.__dict__:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            )
        return getattr(self._ensure_aio_future(), attr)

    def __await__(self):
        return self._ensure_aio_future().__await__()

    def get_future(self):
        """Return the calling thread's asyncio future tracking the task."""
        return self._ensure_aio_future()
# Text of the warning emitted by ``get_default_or_create`` when it has to
# create a session because no default session is set.
warning_msg = """
No session found, local session \
will be created in background, \
it may take a while before execution. \
If you want to new a local session by yourself, \
run code below:
```
import maxframe
maxframe.new_session()
```
"""
class AbstractSession(ABC):
    """
    Base class shared by every session flavour in this module.

    It stores the address and session id identifying a session and manages
    the default session, which is kept on ``AbstractSession`` itself so that
    all subclasses see the same one.
    """

    name = None
    # the default session, shared by all subclasses
    _default = None
    # taken by ``get_default_or_create`` while it looks up / creates the default
    _lock = threading.Lock()

    def __init__(self, address: str, session_id: str):
        self._address = address
        self._session_id = session_id
        self._closed = False

    @property
    def address(self):
        """Address this session was created with."""
        return self._address

    @property
    def session_id(self):
        """Identifier of this session."""
        return self._session_id

    @property
    def closed(self) -> bool:
        """Whether the session has been marked as closed."""
        return self._closed

    def __eq__(self, other):
        # sessions are equal when address and session id both match
        if not isinstance(other, AbstractSession):
            return False
        return (
            self._address == other.address and self._session_id == other.session_id
        )

    def __hash__(self):
        identity = (AbstractSession, self._address, self._session_id)
        return hash(identity)

    def as_default(self) -> "AbstractSession":
        """
        Make this session the default one and return it.
        """
        AbstractSession._default = self
        return self

    @classmethod
    def reset_default(cls):
        """Forget the default session."""
        AbstractSession._default = None

    @classproperty
    def default(cls):
        """The current default session, or ``None`` when there is none."""
        return AbstractSession._default
class AbstractAsyncSession(AbstractSession, metaclass=ABCMeta):
    """
    Interface of sessions whose operations are coroutines.

    Besides the abstract operations it provides ``destroy`` / ``close`` and
    context-manager support, so that ``with session: ...`` closes the session
    on exit.
    """

    @classmethod
    @abstractmethod
    async def init(
        cls, address: str, session_id: str, new: bool = True, **kwargs
    ) -> "AbstractSession":
        """
        Initialize a session.

        Parameters
        ----------
        address : str
            Address.
        session_id : str
            Session ID.
        new : bool
            Whether to create a new session.
        kwargs
            Implementation-specific options.

        Returns
        -------
        session
        """

    async def destroy(self):
        """
        Destroy the session: clear the default session and mark this
        session as closed.
        """
        self.reset_default()
        self._closed = True

    @abstractmethod
    async def execute(self, *tileables, **kwargs) -> ExecutionInfo:
        """
        Execute tileables.

        Parameters
        ----------
        tileables
            Tileables to execute.
        kwargs
            Implementation-specific options.

        Returns
        -------
        execution_info : ExecutionInfo
        """

    @abstractmethod
    async def fetch(self, *tileables, **kwargs) -> list:
        """
        Fetch the data of tileables.

        Parameters
        ----------
        tileables
            Tileables to fetch.

        Returns
        -------
        data : list
        """

    @abstractmethod
    async def decref(self, *tileables_keys):
        """
        Decrease the reference counts of tileables.

        Parameters
        ----------
        tileables_keys : list
            Keys of the tileables.
        """

    @abstractmethod
    async def _get_ref_counts(self) -> Dict[str, int]:
        """
        Get all reference counts.

        Returns
        -------
        ref_counts : dict
        """

    @abstractmethod
    async def fetch_tileable_op_logs(
        self,
        tileable_op_key: str,
        offsets: Union[Dict[str, List[int]], str, int],
        sizes: Union[Dict[str, List[int]], str, int],
    ) -> Dict:
        """
        Fetch the logs of a tileable operator.

        Parameters
        ----------
        tileable_op_key : str
            Key of the tileable operator.
        offsets
            Chunk op key to offsets.
        sizes
            Chunk op key to sizes.

        Returns
        -------
        chunk_key_to_logs : dict
        """

    @abstractmethod
    async def get_total_n_cpu(self):
        """
        Get the number of CPUs of the cluster.

        Returns
        -------
        number_of_cpu : int
        """

    @abstractmethod
    async def get_cluster_versions(self) -> List[str]:
        """
        Get the versions used in the current MaxFrame cluster.

        Returns
        -------
        version_list : list
            List of versions.
        """

    @abstractmethod
    async def get_web_endpoint(self) -> Optional[str]:
        """
        Get the web endpoint of the current session.

        Returns
        -------
        web_endpoint : str
            Web endpoint.
        """

    @abstractmethod
    async def create_remote_object(
        self, session_id: str, name: str, object_cls, *args, **kwargs
    ):
        """
        Create a remote object.

        Parameters
        ----------
        session_id : str
            Session ID.
        name : str
            Name of the object.
        object_cls
            Class of the object.
        args
            Positional arguments.
        kwargs
            Keyword arguments.

        Returns
        -------
        actor_ref
        """

    @abstractmethod
    async def get_remote_object(self, session_id: str, name: str):
        """
        Get a remote object.

        Parameters
        ----------
        session_id : str
            Session ID.
        name : str
            Name of the object.

        Returns
        -------
        actor_ref
        """

    @abstractmethod
    async def destroy_remote_object(self, session_id: str, name: str):
        """
        Destroy a remote object.

        Parameters
        ----------
        session_id : str
            Session ID.
        name : str
            Name of the object.
        """

    async def stop_server(self):
        """
        Stop the server. The base implementation does nothing.
        """

    @abstractmethod
    async def get_logview_address(self, hours=None) -> Optional[str]:
        """
        Get the Logview address.

        Parameters
        ----------
        hours
            Not interpreted here; presumably how long the address stays
            valid -- confirm with the implementations.

        Returns
        -------
        Logview address
        """

    def close(self):
        """Destroy the session, driving ``destroy()`` with ``asyncio.run``."""
        destroy_coro = self.destroy()
        asyncio.run(destroy_coro)

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()
class AbstractSyncSession(AbstractSession, metaclass=ABCMeta):
    """
    Interface of sessions whose operations are plain blocking calls.
    """

    @classmethod
    @abstractmethod
    def init(
        cls,
        address: str,
        session_id: str,
        backend: str = "maxframe",
        new: bool = True,
        **kwargs,
    ) -> "AbstractSession":
        """
        Initialize a session.

        Parameters
        ----------
        address : str
            Address.
        session_id : str
            Session ID.
        backend : str
            Backend.
        new : bool
            Whether to create a new session.
        kwargs
            Implementation-specific options.

        Returns
        -------
        session
        """

    @abstractmethod
    def execute(
        self, tileable, *tileables, show_progress: Union[bool, str] = None, **kwargs
    ) -> Union[List[TileableType], TileableType, ExecutionInfo]:
        """
        Execute tileables.

        Parameters
        ----------
        tileable
            First tileable.
        tileables
            Further tileables.
        show_progress
            Whether to show progress.
        kwargs
            Implementation-specific options.

        Returns
        -------
        result
        """

    @abstractmethod
    def fetch(self, *tileables, **kwargs) -> list:
        """
        Fetch the data of tileables.

        Parameters
        ----------
        tileables
            Tileables to fetch.
        kwargs
            Implementation-specific options.

        Returns
        -------
        fetched_data : list
        """

    @abstractmethod
    def fetch_infos(self, *tileables, fields, **kwargs) -> list:
        """
        Fetch information about tileables.

        Parameters
        ----------
        tileables
            Tileables to inspect.
        fields
            List of fields.
        kwargs
            Implementation-specific options.

        Returns
        -------
        fetched_infos : list
        """

    @abstractmethod
    def decref(self, *tileables_keys):
        """
        Decrease the reference counts of tileables.

        Parameters
        ----------
        tileables_keys : list
            Keys of the tileables.
        """

    @abstractmethod
    def _get_ref_counts(self) -> Dict[str, int]:
        """
        Get all reference counts.

        Returns
        -------
        ref_counts : dict
        """

    @abstractmethod
    def fetch_tileable_op_logs(
        self,
        tileable_op_key: str,
        offsets: Union[Dict[str, List[int]], str, int],
        sizes: Union[Dict[str, List[int]], str, int],
    ) -> Dict:
        """
        Fetch the logs of a tileable operator.

        Parameters
        ----------
        tileable_op_key : str
            Key of the tileable operator.
        offsets
            Chunk op key to offsets.
        sizes
            Chunk op key to sizes.

        Returns
        -------
        chunk_key_to_logs : dict
        """

    @abstractmethod
    def get_total_n_cpu(self):
        """
        Get the number of CPUs of the cluster.

        Returns
        -------
        number_of_cpu : int
        """

    @abstractmethod
    def get_cluster_versions(self) -> List[str]:
        """
        Get the versions used in the current MaxFrame cluster.

        Returns
        -------
        version_list : list
            List of versions.
        """

    @abstractmethod
    def get_web_endpoint(self) -> Optional[str]:
        """
        Get the web endpoint of the current session.

        Returns
        -------
        web_endpoint : str
            Web endpoint.
        """

    def fetch_log(
        self,
        tileables: List[TileableType],
        offsets: List[int] = None,
        sizes: List[int] = None,
    ):
        """
        Fetch the custom logs of tileables through this session.

        The work is done by ``fetch`` from ``.core.custom_log``, which is
        imported here rather than at module level.

        Parameters
        ----------
        tileables : list
            Tileables whose logs are wanted.
        offsets : list, optional
            Forwarded unchanged.
        sizes : list, optional
            Forwarded unchanged.
        """
        from .core.custom_log import fetch as fetch_custom_logs

        return fetch_custom_logs(tileables, self, offsets=offsets, sizes=sizes)

    @abstractmethod
    def get_logview_address(self, hours=None) -> Optional[str]:
        """
        Get the Logview address.

        Parameters
        ----------
        hours
            Not interpreted here; presumably how long the address stays
            valid -- confirm with the implementations.

        Returns
        -------
        logview address
        """
def _delegate_to_isolated_session(func: Union[Callable, Coroutine]):
    """
    Turn a stub method into one that forwards to the isolated session.

    The returned method looks up the method named like ``func`` on
    ``session._isolated_session``, schedules its coroutine on
    ``session._loop`` and hands back the outcome. If ``func`` is a coroutine
    function the wrapper is one as well and awaits the outcome; otherwise
    the wrapper blocks until the outcome is available. The body of ``func``
    itself is never executed.
    """

    def _schedule(session, args, kwargs):
        # build the coroutine on the isolated session, then submit it to the
        # session's loop; yields a concurrent future
        target = getattr(session._isolated_session, func.__name__)
        pending = target(*args, **kwargs)
        return asyncio.run_coroutine_threadsafe(pending, session._loop)

    if not asyncio.iscoroutinefunction(func):

        @wraps(func)
        def delegate(session: "SyncSession", *args, **kwargs):
            return _schedule(session, args, kwargs).result()

        return delegate

    @wraps(func)
    async def delegate(session: "AsyncSession", *args, **kwargs):
        return await asyncio.wrap_future(_schedule(session, args, kwargs))

    return delegate
# Registry filled by ``IsolatedAsyncSession.register_schemes``: maps an
# address scheme (``None`` for addresses without one) to the isolated
# session class serving it.
_schemes_to_isolated_session_cls: Dict[
    Optional[str], Type["IsolatedAsyncSession"]
] = {}
class IsolatedAsyncSession(AbstractAsyncSession):
    """
    Base class of isolated sessions, which can be registered for
    different address schemes.
    """

    # subclasses set this to the list of schemes they serve
    schemes = None

    @classmethod
    def register_schemes(cls, overwrite: bool = True):
        """
        Register this class for every scheme listed in ``cls.schemes``.

        Parameters
        ----------
        overwrite : bool
            When False, schemes that already have a registered class keep it.
        """
        assert isinstance(cls.schemes, list)
        registry = _schemes_to_isolated_session_cls
        for scheme in cls.schemes:
            if not overwrite and scheme in registry:
                continue
            registry[scheme] = cls
def _get_isolated_session_cls(address: str) -> Type[IsolatedAsyncSession]:
    """
    Look up the isolated session class registered for an address.

    Addresses that do not contain ``":/"`` are treated as having no scheme
    and are looked up under ``None``.

    Raises
    ------
    ValueError
        If no class is registered for the scheme.
    """
    url_scheme = None
    if ":/" in address:
        url_scheme = urlparse(address).scheme or None

    session_cls = _schemes_to_isolated_session_cls.get(url_scheme)
    if session_cls is not None:
        return session_cls
    raise ValueError(f"Address scheme {url_scheme} not supported")
class AsyncSession(AbstractAsyncSession):
    """
    Asynchronous front end of an isolated session.

    Each operation is submitted to the isolation's event loop with
    ``asyncio.run_coroutine_threadsafe`` and the outcome is awaited from the
    caller's loop.

    Parameters
    ----------
    address : str
        Address of the session.
    session_id : str
        Session ID.
    isolated_session : IsolatedAsyncSession
        Session doing the actual work. A wrapper exposing
        ``_isolated_session`` is unwrapped first.
    isolation : Isolation
        Isolation whose loop runs the isolated session's coroutines.
    """

    def __init__(
        self,
        address: str,
        session_id: str,
        isolated_session: IsolatedAsyncSession,
        isolation: Isolation,
    ):
        super().__init__(address, session_id)
        self._isolated_session = _get_isolated_session(isolated_session)
        self._isolation = isolation
        self._loop = isolation.loop

    @classmethod
    def from_isolated_session(
        cls, isolated_session: IsolatedAsyncSession
    ) -> "AsyncSession":
        """Wrap an existing isolated session, using the current isolation."""
        return cls(
            isolated_session.address,
            isolated_session.session_id,
            isolated_session,
            get_isolation(),
        )

    @property
    def client(self):
        """Client of the wrapped isolated session."""
        return self._isolated_session.client

    @client.setter
    def client(self, client: ClientType):
        self._isolated_session.client = client

    @classmethod
    @implements(AbstractAsyncSession.init)
    async def init(
        cls,
        address: str,
        session_id: str,
        backend: str = "maxframe",
        new: bool = True,
        **kwargs,
    ) -> "AbstractSession":
        # ``loop`` / ``use_uvloop`` are consumed from kwargs here
        isolation = ensure_isolation_created(kwargs)
        coro = _get_isolated_session_cls(address).init(
            address, session_id, backend, new=new, **kwargs
        )
        fut = asyncio.run_coroutine_threadsafe(coro, isolation.loop)
        isolated_session = await asyncio.wrap_future(fut)
        return AsyncSession(address, session_id, isolated_session, isolation)

    def as_default(self) -> AbstractSession:
        """
        Make the wrapped isolated session (not this wrapper) the default
        session and return this wrapper.
        """
        AbstractSession._default = self._isolated_session
        return self

    @implements(AbstractAsyncSession.destroy)
    async def destroy(self):
        coro = self._isolated_session.destroy()
        await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro, self._loop))
        self.reset_default()
        # mirror AbstractAsyncSession.destroy so that ``closed`` reports the
        # destroyed state on this wrapper as well
        self._closed = True

    @implements(AbstractAsyncSession.execute)
    @_delegate_to_isolated_session
    async def execute(self, *tileables, **kwargs) -> ExecutionInfo:
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.fetch)
    async def fetch(self, *tileables, **kwargs) -> list:
        coro = _fetch(*tileables, session=self._isolated_session, **kwargs)
        return await asyncio.wrap_future(
            asyncio.run_coroutine_threadsafe(coro, self._loop)
        )

    # ``decref`` is abstract in AbstractAsyncSession; without an
    # implementation this class stays abstract and cannot be instantiated.
    @implements(AbstractAsyncSession.decref)
    @_delegate_to_isolated_session
    async def decref(self, *tileables_keys):
        pass  # pragma: no cover

    @implements(AbstractAsyncSession._get_ref_counts)
    @_delegate_to_isolated_session
    async def _get_ref_counts(self) -> Dict[str, int]:
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.fetch_tileable_op_logs)
    @_delegate_to_isolated_session
    async def fetch_tileable_op_logs(
        self,
        tileable_op_key: str,
        offsets: Union[Dict[str, List[int]], str, int],
        sizes: Union[Dict[str, List[int]], str, int],
    ) -> Dict:
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.get_total_n_cpu)
    @_delegate_to_isolated_session
    async def get_total_n_cpu(self):
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.get_cluster_versions)
    @_delegate_to_isolated_session
    async def get_cluster_versions(self) -> List[str]:
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.create_remote_object)
    @_delegate_to_isolated_session
    async def create_remote_object(
        self, session_id: str, name: str, object_cls, *args, **kwargs
    ):
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.get_remote_object)
    @_delegate_to_isolated_session
    async def get_remote_object(self, session_id: str, name: str):
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.destroy_remote_object)
    @_delegate_to_isolated_session
    async def destroy_remote_object(self, session_id: str, name: str):
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.get_web_endpoint)
    @_delegate_to_isolated_session
    async def get_web_endpoint(self) -> Optional[str]:
        pass  # pragma: no cover

    @implements(AbstractAsyncSession.stop_server)
    async def stop_server(self):
        coro = self._isolated_session.stop_server()
        await asyncio.wrap_future(asyncio.run_coroutine_threadsafe(coro, self._loop))
        stop_isolation()

    @implements(AbstractAsyncSession.get_logview_address)
    @_delegate_to_isolated_session
    async def get_logview_address(self, hours=None) -> Optional[str]:
        pass  # pragma: no cover
class ProgressBar:
    """
    Thin wrapper around a ``tqdm`` bar that may be disabled.

    Parameters
    ----------
    show_progress
        Falsy disables the bar. ``"auto"`` shows a bar only when ``tqdm``
        can be imported. Any other truthy value requires ``tqdm``.

    Raises
    ------
    ImportError
        If a bar is required but ``tqdm`` is not installed.

    Notes
    -----
    A disabled bar still tracks ``last_progress`` and can be used as a
    context manager; entering and leaving it then do nothing.
    """

    def __init__(self, show_progress):
        if not show_progress:
            self.progress_bar = None
        else:
            try:
                from tqdm.auto import tqdm
            except ImportError:
                if show_progress != "auto":  # pragma: no cover
                    raise ImportError("tqdm is required to show progress")
                else:
                    self.progress_bar = None
            else:
                self.progress_bar = tqdm(
                    total=100,
                    bar_format="{l_bar}{bar}| {n:6.2f}/{total_fmt} "
                    "[{elapsed}<{remaining}, {rate_fmt}{postfix}]",
                )

        # highest progress seen so far, on the bar's 0-100 scale
        self.last_progress: float = 0.0

    @property
    def show_progress(self) -> bool:
        """Whether a real bar is being displayed."""
        return self.progress_bar is not None

    def __enter__(self):
        # a disabled bar has nothing to enter
        if self.progress_bar is not None:
            self.progress_bar.__enter__()
        return self

    def __exit__(self, *_):
        # a disabled bar has nothing to close; exceptions always propagate
        if self.progress_bar is not None:
            self.progress_bar.__exit__(*_)

    def update(self, progress: float):
        """
        Advance the bar to ``progress``.

        Values above 100 are clamped and the bar never moves backwards.
        """
        progress = min(progress, 100)
        last_progress = self.last_progress
        if self.progress_bar:
            incr = max(progress - last_progress, 0)
            self.progress_bar.update(incr)
        self.last_progress = max(last_progress, progress)
class SyncSession(AbstractSyncSession):
    """
    Blocking front end of an isolated session.

    Each operation is submitted to the isolation's event loop with
    ``asyncio.run_coroutine_threadsafe`` and the calling thread blocks on
    the resulting future.

    Parameters
    ----------
    address : str
        Address of the session.
    session_id : str
        Session ID.
    isolated_session : IsolatedAsyncSession
        Session doing the actual work. A wrapper exposing
        ``_isolated_session`` is unwrapped first.
    isolation : Isolation
        Isolation whose loop runs the isolated session's coroutines.
    """

    # not referenced anywhere else in this file
    _execution_pool = concurrent.futures.ThreadPoolExecutor(1)

    def __init__(
        self,
        address: str,
        session_id: str,
        isolated_session: IsolatedAsyncSession,
        isolation: Isolation,
    ):
        super().__init__(address, session_id)
        self._isolated_session = _get_isolated_session(isolated_session)
        self._isolation = isolation
        self._loop = isolation.loop

    @classmethod
    def from_isolated_session(
        cls, isolated_session: IsolatedAsyncSession
    ) -> "SyncSession":
        """Wrap an existing isolated session, using the current isolation."""
        return cls(
            isolated_session.address,
            isolated_session.session_id,
            isolated_session,
            get_isolation(),
        )

    @classmethod
    def init(
        cls,
        address: str,
        session_id: str,
        backend: str = "maxframe",
        new: bool = True,
        **kwargs,
    ) -> "AbstractSession":
        """
        Create the isolated session for ``address`` on the isolation's loop
        and wrap it. ``loop`` / ``use_uvloop`` are consumed from ``kwargs``;
        the remaining options go to the isolated session class.
        """
        isolation = ensure_isolation_created(kwargs)
        coro = _get_isolated_session_cls(address).init(
            address, session_id, backend, new=new, **kwargs
        )
        fut = asyncio.run_coroutine_threadsafe(coro, isolation.loop)
        isolated_session = fut.result()
        return SyncSession(address, session_id, isolated_session, isolation)

    def as_default(self) -> AbstractSession:
        """
        Make the wrapped isolated session (not this wrapper) the default
        session and return this wrapper.
        """
        AbstractSession._default = self._isolated_session
        return self

    @property
    def _session(self):
        return self._isolated_session

    @property
    def session_id(self):
        """
        Session ID reported by the isolated session, falling back to the ID
        this wrapper was created with when that one is empty or unavailable.
        """
        try:
            return self._session.session_id or self._session_id
        except AttributeError:
            return self._session_id

    def _new_cancel_event(self):
        # the event is created from a coroutine running on the session loop
        async def new_event():
            return asyncio.Event()

        return asyncio.run_coroutine_threadsafe(new_event(), self._loop).result()

    @implements(AbstractSyncSession.execute)
    def execute(
        self,
        tileable,
        *tileables,
        show_progress: Union[bool, str] = None,
        warn_duplicated_execution: bool = None,
        **kwargs,
    ) -> Union[List[TileableType], TileableType, ExecutionInfo]:
        wait = kwargs.get("wait", True)
        # add an intermediate future for cancel tests
        result_future = kwargs.pop("result_future", None) or SyncFuture()
        if show_progress is None:
            show_progress = options.show_progress
        if warn_duplicated_execution is None:
            warn_duplicated_execution = options.warn_duplicated_execution

        to_execute_tileables = []
        for t in (tileable,) + tileables:
            to_execute_tileables.extend(t.op.outputs)

        cancelled = kwargs.get("cancelled")
        if cancelled is None:
            cancelled = kwargs["cancelled"] = self._new_cancel_event()

        coro = _execute(
            *set(to_execute_tileables),
            session=self._isolated_session,
            show_progress=show_progress,
            warn_duplicated_execution=warn_duplicated_execution,
            **kwargs,
        )
        fut = asyncio.run_coroutine_threadsafe(coro, self._loop)
        relay_future(result_future, fut)
        try:
            execution_info: ExecutionInfo = result_future.result(
                timeout=self._isolated_session.timeout
            )
        except KeyboardInterrupt:  # pragma: no cover
            logger.warning("Cancelling running task")
            cancelled.set()
            # keep what the coroutine returns: the ``wait=False`` branch
            # below reads ``execution_info`` and must not find it unbound
            execution_info = fut.result()
            logger.warning("Cancel finished")

        if wait:
            return tileable if len(tileables) == 0 else [tileable] + list(tileables)
        else:
            aio_task = execution_info.aio_task

            async def run():
                await aio_task
                return tileable if len(tileables) == 0 else [tileable] + list(tileables)

            async def driver():
                # the task has to be created from within the execution loop
                return asyncio.create_task(run())

            new_aio_task = asyncio.run_coroutine_threadsafe(
                driver(), execution_info.loop
            ).result()
            # same progress / profiling holders, but the result of the new
            # handle is the caller's tileable(s)
            new_execution_info = ExecutionInfo(
                new_aio_task,
                execution_info._progress,
                execution_info._profiling,
                execution_info.loop,
                to_execute_tileables,
            )
            return new_execution_info

    @implements(AbstractSyncSession.fetch)
    def fetch(self, *tileables, **kwargs) -> list:
        coro = _fetch(*tileables, session=self._isolated_session, **kwargs)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    @implements(AbstractSyncSession.fetch_infos)
    def fetch_infos(self, *tileables, fields, **kwargs) -> list:
        coro = _fetch_infos(
            *tileables, fields=fields, session=self._isolated_session, **kwargs
        )
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    @implements(AbstractSyncSession.decref)
    def decref(self, *tileable_keys):
        coro = _decref(*tileable_keys, session=self._isolated_session)
        return asyncio.run_coroutine_threadsafe(coro, self._loop).result()

    @implements(AbstractSyncSession._get_ref_counts)
    @_delegate_to_isolated_session
    def _get_ref_counts(self) -> Dict[str, int]:
        pass  # pragma: no cover

    @implements(AbstractSyncSession.fetch_tileable_op_logs)
    @_delegate_to_isolated_session
    def fetch_tileable_op_logs(
        self,
        tileable_op_key: str,
        offsets: Union[Dict[str, List[int]], str, int],
        sizes: Union[Dict[str, List[int]], str, int],
    ) -> Dict:
        pass  # pragma: no cover

    @implements(AbstractSyncSession.get_total_n_cpu)
    @_delegate_to_isolated_session
    def get_total_n_cpu(self):
        pass  # pragma: no cover

    @implements(AbstractSyncSession.get_web_endpoint)
    @_delegate_to_isolated_session
    def get_web_endpoint(self) -> Optional[str]:
        pass  # pragma: no cover

    @implements(AbstractSyncSession.get_cluster_versions)
    @_delegate_to_isolated_session
    def get_cluster_versions(self) -> List[str]:
        pass  # pragma: no cover

    @implements(AbstractSyncSession.get_logview_address)
    @_delegate_to_isolated_session
    def get_logview_address(self, hours=None) -> Optional[str]:
        pass  # pragma: no cover

    def destroy(self):
        """Destroy the isolated session, clear the default, mark as closed."""
        coro = self._isolated_session.destroy()
        asyncio.run_coroutine_threadsafe(coro, self._loop).result()
        self.reset_default()
        # without this ``closed`` would stay False on the wrapper for ever
        self._closed = True

    def stop_server(self, isolation=True):
        """
        Stop the server, waiting at most five seconds for it.

        The default session is cleared, and the isolation stopped when
        ``isolation`` is true, even if stopping the server fails.
        """
        try:
            coro = self._isolated_session.stop_server()
            future = asyncio.run_coroutine_threadsafe(coro, self._loop)
            future.result(timeout=5)
        finally:
            self.reset_default()
            if isolation:
                stop_isolation()

    def close(self):
        """Alias of :meth:`destroy`."""
        self.destroy()

    def __enter__(self):
        return self

    def __exit__(self, *_):
        self.close()
async def _execute_with_progress(
    execution_info: ExecutionInfo,
    progress_bar: ProgressBar,
    progress_update_interval: Union[int, float],
    cancelled: asyncio.Event,
):
    """
    Wait for an execution while keeping the progress bar up to date.

    The execution's future is polled every ``progress_update_interval``
    seconds. The loop stops once the execution is done or ``cancelled`` is
    set; the bar is not touched any more after cancellation. The outcome of
    the execution is not retrieved here.
    """
    with progress_bar:
        while not cancelled.is_set():
            finished, _unfinished = await asyncio.wait(
                [execution_info.get_future()], timeout=progress_update_interval
            )
            if finished:
                # execution completed: fill the bar unless cancelled meanwhile
                if not cancelled.is_set():
                    progress_bar.update(100)
                break

            # still running: report the latest known progress, if any
            if cancelled.is_set():
                continue
            if execution_info.progress() is not None:
                progress_bar.update(execution_info.progress() * 100)
async def _execute(
    *tileables: Tuple[TileableType, ...],
    session: IsolatedAsyncSession = None,
    wait: bool = True,
    show_progress: Union[bool, str] = "auto",
    progress_update_interval: Union[int, float] = 1,
    cancelled: asyncio.Event = None,
    **kwargs,
):
    """
    Execute tileables on an isolated session.

    Parameters
    ----------
    tileables
        Tileables to execute.
    session : IsolatedAsyncSession
        Session that runs the execution.
    wait : bool
        When True, wait until the execution finishes or is cancelled.
        When False, return right after the execution has been submitted.
    show_progress
        Passed to ``ProgressBar``; decides whether a bar is displayed.
    progress_update_interval
        Seconds between two progress bar refreshes.
    cancelled : asyncio.Event
        Event that, once set, makes a waiting call cancel the execution.
        A fresh event is created when omitted.
    kwargs
        Passed to ``session.execute``.

    Returns
    -------
    ExecutionInfo or None
        The execution handle when ``wait`` is False, otherwise nothing.
    """
    execution_info = await session.execute(*tileables, **kwargs)

    def _attach_session(future: asyncio.Future):
        # only tileables of a successful execution get the session attached
        if future.exception() is None:
            for t in execution_info.to_execute_tileables:
                t._attach_session(session)

    # ``add_done_callback`` is not defined on ExecutionInfo itself; it is
    # resolved on the asyncio future through ExecutionInfo.__getattr__
    execution_info.add_done_callback(_attach_session)
    cancelled = cancelled or asyncio.Event()

    if wait:
        progress_bar = ProgressBar(show_progress)
        if progress_bar.show_progress:
            await _execute_with_progress(
                execution_info, progress_bar, progress_update_interval, cancelled
            )
        else:
            # no bar: wait for whichever comes first, completion or cancel
            exec_task = asyncio.ensure_future(execution_info)
            cancel_task = asyncio.ensure_future(cancelled.wait())
            await asyncio.wait(
                [exec_task, cancel_task], return_when=asyncio.FIRST_COMPLETED
            )
        if cancelled.is_set():
            # drop the callback before cancelling so that the session is not
            # attached to tileables of a cancelled execution
            execution_info.remove_done_callback(_attach_session)
            execution_info.cancel()
        else:
            # set cancelled to avoid wait task leak
            cancelled.set()
        # surfaces the outcome of the execution, including its exception
        await execution_info
    else:
        return execution_info
def execute(
    tileable: TileableType,
    *tileables: Tuple[TileableType, ...],
    session: SyncSession = None,
    wait: bool = True,
    new_session_kwargs: dict = None,
    show_progress: Union[bool, str] = None,
    progress_update_interval=1,
    **kwargs,
):
    """
    Execute tileables on a session.

    Parameters
    ----------
    tileable
        A tileable, or a tuple / list of tileables when it is the only
        positional argument.
    tileables
        Further tileables.
    session : SyncSession, optional
        Session to use. Without one, the default session is used and
        created if necessary.
    wait : bool
        Passed to the session's ``execute``.
    new_session_kwargs : dict, optional
        Options for creating the session when none exists yet.
    show_progress
        Passed to the session's ``execute``.
    progress_update_interval
        Passed to the session's ``execute``.
    kwargs
        Passed to the session's ``execute``.

    Returns
    -------
    Whatever the session's ``execute`` returns.
    """
    # a single tuple / list argument is spread into individual tileables
    if not tileables and isinstance(tileable, (tuple, list)):
        packed = tileable
        tileable, tileables = packed[0], packed[1:]

    if session is None:
        creation_kwargs = new_session_kwargs or dict()
        session = get_default_or_create(**creation_kwargs)
    sync_session = _ensure_sync(session)

    return sync_session.execute(
        tileable,
        *tileables,
        wait=wait,
        show_progress=show_progress,
        progress_update_interval=progress_update_interval,
        **kwargs,
    )
async def _fetch(
    tileable: TileableType,
    *tileables: Tuple[TileableType, ...],
    session: IsolatedAsyncSession = None,
    **kwargs,
):
    """
    Fetch tileable data through the isolated session behind ``session``.

    A single tuple argument is spread into individual tileables. When only
    one tileable remains, its data is returned directly instead of a list.
    """
    if not tileables and isinstance(tileable, tuple):
        packed = tileable
        tileable, tileables = packed[0], packed[1:]

    isolated = _get_isolated_session(session)
    fetched = await isolated.fetch(tileable, *tileables, **kwargs)
    if tileables:
        return fetched
    return fetched[0]
async def _fetch_infos(
    tileable: TileableType,
    *tileables: Tuple[TileableType, ...],
    session: IsolatedAsyncSession = None,
    fields: List[str] = None,
    **kwargs,
):
    """
    Fetch tileable infos through the isolated session behind ``session``.

    A single tuple argument is spread into individual tileables. When only
    one tileable remains, its info is returned directly instead of a list.
    """
    if not tileables and isinstance(tileable, tuple):
        packed = tileable
        tileable, tileables = packed[0], packed[1:]

    isolated = _get_isolated_session(session)
    infos = await isolated.fetch_infos(tileable, *tileables, fields=fields, **kwargs)
    if tileables:
        return infos
    return infos[0]
async def _decref(
    tileable_key: str,
    *tileable_keys: Tuple[str, ...],
    session: IsolatedAsyncSession = None,
):
    """
    Decref tileable keys through the isolated session behind ``session``.

    A single tuple argument is spread into individual keys.
    """
    if not tileable_keys and isinstance(tileable_key, tuple):
        packed = tileable_key
        tileable_key, tileable_keys = packed[0], packed[1:]

    isolated = _get_isolated_session(session)
    await isolated.decref(tileable_key, *tileable_keys)
def fetch(
    tileable: TileableType,
    *tileables: Tuple[TileableType],
    session: SyncSession = None,
    **kwargs,
):
    """
    Fetch the data of tileables.

    Parameters
    ----------
    tileable
        A tileable, or a tuple / list of tileables when it is the only
        positional argument.
    tileables
        Further tileables.
    session : SyncSession, optional
        Session to use; defaults to the default session.
    kwargs
        Passed to the session's ``fetch``.

    Raises
    ------
    ValueError
        If no session is given and there is no default session.
    """
    if not tileables and isinstance(tileable, (tuple, list)):
        packed = tileable
        tileable, tileables = packed[0], packed[1:]

    if session is None:
        session = get_default_session()
        if session is None:  # pragma: no cover
            raise ValueError("No session found")

    sync_session = _ensure_sync(session)
    return sync_session.fetch(tileable, *tileables, **kwargs)
def fetch_infos(
    tileable: TileableType,
    *tileables: Tuple[TileableType],
    fields: List[str],
    session: SyncSession = None,
    **kwargs,
):
    """
    Fetch information about tileables.

    Parameters
    ----------
    tileable
        A tileable, or a tuple / list of tileables when it is the only
        positional argument.
    tileables
        Further tileables.
    fields : list
        List of fields.
    session : SyncSession, optional
        Session to use; defaults to the default session.
    kwargs
        Passed to the session's ``fetch_infos``.

    Raises
    ------
    ValueError
        If no session is given and there is no default session.
    """
    # lists are spread as well as tuples, matching ``fetch`` and ``execute``
    if isinstance(tileable, (tuple, list)) and len(tileables) == 0:
        tileable, tileables = tileable[0], tileable[1:]
    if session is None:
        session = get_default_session()
        if session is None:  # pragma: no cover
            raise ValueError("No session found")
    session = _ensure_sync(session)
    return session.fetch_infos(tileable, *tileables, fields=fields, **kwargs)
def fetch_log(*tileables: TileableType, session: SyncSession = None, **kwargs):
    """
    Fetch the logs of tileables.

    Parameters
    ----------
    tileables
        Tileables, or one list / tuple holding them.
    session : SyncSession, optional
        Session to use; defaults to the default session.
    kwargs
        Passed to the session's ``fetch_log``.

    Raises
    ------
    ValueError
        If no session is given and there is no default session.
    """
    single_collection = len(tileables) == 1 and isinstance(
        tileables[0], (list, tuple)
    )
    if single_collection:
        tileables = tileables[0]

    if session is None:
        session = get_default_session()
        if session is None:  # pragma: no cover
            raise ValueError("No session found")

    sync_session = _ensure_sync(session)
    return sync_session.fetch_log(list(tileables), **kwargs)
def ensure_isolation_created(kwargs):
    """
    Return the current isolation, creating one if there is none.

    ``loop`` and ``use_uvloop`` are always removed from ``kwargs``. They
    only matter when an isolation has to be created: ``loop`` is used as
    its event loop when given; otherwise a loop is created, from uvloop
    when ``use_uvloop`` is truthy and uvloop can be imported. With
    ``use_uvloop="auto"`` (the default) a missing uvloop falls back to a
    standard asyncio loop; with any other truthy value the ImportError is
    raised.
    """
    loop = kwargs.pop("loop", None)
    use_uvloop = kwargs.pop("use_uvloop", "auto")

    try:
        return get_isolation()
    except KeyError:
        # no isolation yet
        if loop is None:
            if use_uvloop:
                try:
                    import uvloop

                    loop = uvloop.new_event_loop()
                except ImportError:
                    if use_uvloop != "auto":  # pragma: no cover
                        raise
            if loop is None:
                loop = asyncio.new_event_loop()
        return new_isolation(loop=loop)
def _new_session_id():
    """Return a random 24-character string of ASCII letters and digits."""
    alphabet = string.ascii_letters + string.digits
    picked = [random.choice(alphabet) for _ in range(24)]
    return "".join(picked)
async def _new_session(
    address: str,
    session_id: str = None,
    backend: str = "maxframe",
    default: bool = False,
    **kwargs,
) -> AbstractSession:
    """
    Create a new asynchronous session.

    Parameters
    ----------
    address : str
        Address of the session.
    session_id : str, optional
        Session ID; a random one is generated when omitted.
    backend : str
        Backend.
    default : bool
        Whether to make the new session the default one.
    kwargs
        Passed to ``AsyncSession.init``.
    """
    effective_id = _new_session_id() if session_id is None else session_id
    created = await AsyncSession.init(
        address, session_id=effective_id, backend=backend, new=True, **kwargs
    )
    if default:
        created.as_default()
    return created
def new_session(
    address: Union[str, ODPS] = None,
    session_id: str = None,
    backend: str = "maxframe",
    default: bool = True,
    new: bool = True,
    odps_entry: Optional[ODPS] = None,
    **kwargs,
) -> AbstractSession:
    """
    Create a synchronous session.

    Parameters
    ----------
    address : str or ODPS, optional
        Address of the session. An ``ODPS`` object passed here is taken as
        ``odps_entry`` instead. Without an address, the insecure ODPS
        session scheme of ``maxframe_client`` is used.
    session_id : str, optional
        Session ID; a random one is generated when omitted.
    backend : str
        Backend.
    default : bool
        Whether to make the new session the default one.
    new : bool
        Passed to ``SyncSession.init``.
    odps_entry : ODPS, optional
        ODPS entry; falls back to the global entry, then to the one built
        from the environment.
    kwargs
        ``loop`` / ``use_uvloop`` control isolation creation; everything
        else is passed to ``SyncSession.init``.
    """
    from maxframe_client.session import register_session_schemes

    register_session_schemes()

    if isinstance(address, ODPS):
        odps_entry = address
        address = None

    # creates the isolation if needed, consuming ``loop`` / ``use_uvloop``
    ensure_isolation_created(kwargs)

    if not odps_entry:
        odps_entry = ODPS.from_global() or ODPS.from_environments()

    if address is None:
        from maxframe_client.session.consts import ODPS_SESSION_INSECURE_SCHEME

        address = f"{ODPS_SESSION_INSECURE_SCHEME}://"

    if session_id is None:
        session_id = _new_session_id()

    created = SyncSession.init(
        address,
        session_id=session_id,
        backend=backend,
        new=new,
        odps_entry=odps_entry,
        **kwargs,
    )
    if default:
        created.as_default()
    return created
def get_default_session() -> Optional[SyncSession]:
    """Return the default session wrapped as a ``SyncSession``, if any."""
    if AbstractSession.default is not None:
        return SyncSession.from_isolated_session(AbstractSession.default)
    return None
def clear_default_session():
    """Forget the default session."""
    AbstractSession.reset_default()
def get_default_async_session() -> Optional[AsyncSession]:
    """Return the default session wrapped as an ``AsyncSession``, if any."""
    if AbstractSession.default is not None:
        return AsyncSession.from_isolated_session(AbstractSession.default)
    return None
def get_default_or_create(**kwargs):
    """
    Return the default session as a ``SyncSession``, creating it if needed.

    Lookup and creation happen while holding ``AbstractSession._lock``.
    When a session has to be created, a warning is issued first and
    ``kwargs`` are passed to ``new_session``; ``odps_entry`` falls back to
    the global ODPS entry, then to the one built from the environment.
    """
    with AbstractSession._lock:
        current = AbstractSession.default
        if current is None:
            # no session attached, try to create one
            warnings.warn(warning_msg)
            explicit_entry = kwargs.pop("odps_entry", None)
            odps_entry = (
                explicit_entry or ODPS.from_global() or ODPS.from_environments()
            )
            current = new_session(odps_entry=odps_entry, **kwargs)
            current.as_default()

    if isinstance(current, IsolatedAsyncSession):
        current = SyncSession.from_isolated_session(current)
    return _ensure_sync(current)
def stop_server():
    """Stop the server of the default session; no-op without a default."""
    if not AbstractSession.default:
        return
    wrapper = SyncSession.from_isolated_session(AbstractSession.default)
    wrapper.stop_server()
def _get_isolated_session(session: AbstractSession) -> IsolatedAsyncSession:
    """
    Return the isolated session behind ``session``.

    Objects exposing ``_isolated_session`` are unwrapped; anything else is
    returned unchanged.
    """
    if not hasattr(session, "_isolated_session"):
        return session
    return session._isolated_session
def _ensure_sync(session: AbstractSession) -> SyncSession:
    """
    Return ``session`` as a ``SyncSession``.

    A ``SyncSession`` is returned as is; any other session is unwrapped to
    its isolated session and wrapped anew.
    """
    if not isinstance(session, SyncSession):
        unwrapped = _get_isolated_session(session)
        return SyncSession.from_isolated_session(unwrapped)
    return session