# core/maxframe_client/session/odps.py
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import asyncio
import copy
import logging
import time
import weakref
from numbers import Integral
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from urllib.parse import urlparse

import numpy as np
import pandas as pd
from odps import ODPS
from odps import options as odps_options
from odps.console import in_ipython_frontend

from maxframe.config import options
from maxframe.core import Entity, TileableGraph, build_fetch, enter_mode
from maxframe.core.operator import Fetch
from maxframe.dataframe import read_odps_table
from maxframe.dataframe.core import DATAFRAME_TYPE, SERIES_TYPE
from maxframe.dataframe.datasource import PandasDataSourceOperator
from maxframe.dataframe.datasource.read_odps_table import DataFrameReadODPSTable
from maxframe.errors import (
MaxFrameError,
NoTaskServerResponseError,
SessionAlreadyClosedError,
)
from maxframe.io.objects import get_object_io_handler
from maxframe.io.odpsio import (
ODPSTableIO,
ODPSVolumeWriter,
pandas_to_arrow,
pandas_to_odps_schema,
)
from maxframe.protocol import (
DagInfo,
DagStatus,
ODPSTableResultInfo,
ODPSVolumeResultInfo,
ResultInfo,
SessionInfo,
)
from maxframe.session import (
AbstractSession,
ExecutionInfo,
IsolatedAsyncSession,
Profiling,
Progress,
)
from maxframe.tensor.datasource import ArrayDataSource
from maxframe.typing_ import TileableType
from maxframe.utils import (
ToThreadMixin,
build_session_volume_name,
build_temp_table_name,
get_default_table_properties,
str_to_bool,
sync_pyodps_options,
)

from ..clients.framedriver import FrameDriverClient
from ..fetcher import get_fetcher_cls
from .consts import RESTFUL_SESSION_INSECURE_SCHEME, RESTFUL_SESSION_SECURE_SCHEME
from .graph import gen_submit_tileable_graph

logger = logging.getLogger(__name__)


class MaxFrameServiceCaller(metaclass=abc.ABCMeta):
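    """
    Abstract transport between the client session and the MaxFrame service.

    Concrete subclasses (e.g. the RESTful caller below) implement session
    lifecycle calls as well as DAG submission, polling and cancellation.
    """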
def get_settings_to_upload(self) -> Dict[str, Any]:
sql_settings = (odps_options.sql.settings or {}).copy()
sql_settings.update(options.sql.settings or {})
quota_name = options.session.quota_name or getattr(
odps_options, "quota_name", None
)
quota_settings = {
sql_settings.get("odps.task.wlm.quota", None),
options.spe.task.settings.get("odps.task.wlm.quota", None),
options.pythonpack.task.settings.get("odps.task.wlm.quota", None),
quota_name,
}.difference([None])
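        # the quota name may come from SQL settings, SPE / pythonpack task
        # settings or the session option; all configured values must agree,
        # otherwise there is no way to tell which quota to use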
if len(quota_settings) >= 2:
raise ValueError(
"Quota settings are conflicting: %s" % ", ".join(sorted(quota_settings))
)
elif len(quota_settings) == 1:
quota_name = quota_settings.pop()
lifecycle = options.session.table_lifecycle or odps_options.lifecycle
temp_lifecycle = (
options.session.temp_table_lifecycle or odps_options.temp_lifecycle
)
enable_schema = options.session.enable_schema
default_schema = options.session.default_schema
        if getattr(self, "_odps_entry", None) is not None:
            default_schema = default_schema or self._odps_entry.schema
# use flags in sql settings
if sql_settings.get("odps.default.schema"):
default_schema = sql_settings["odps.default.schema"]
if str_to_bool(
sql_settings.get("odps.namespace.schema") or "false"
) or str_to_bool(
sql_settings.get("odps.sql.allow.namespace.schema") or "false"
):
enable_schema = True
mf_settings = dict(options.to_dict(remote_only=True).items())
mf_settings["sql.settings"] = sql_settings
mf_settings["session.table_lifecycle"] = lifecycle
mf_settings["session.temp_table_lifecycle"] = temp_lifecycle
mf_settings["session.quota_name"] = quota_name
if enable_schema is not None:
mf_settings["session.enable_schema"] = enable_schema
if options.session.enable_high_availability is None:
mf_settings["session.enable_high_availability"] = not in_ipython_frontend()
mf_settings["session.default_schema"] = default_schema or "default"
return mf_settings
@abc.abstractmethod
def create_session(self) -> SessionInfo:
raise NotImplementedError
@abc.abstractmethod
def delete_session(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def submit_dag(
self,
dag: TileableGraph,
managed_input_infos: Dict[str, ResultInfo],
new_settings: Dict[str, Any] = None,
) -> DagInfo:
raise NotImplementedError
@abc.abstractmethod
def get_dag_info(self, dag_id: str) -> DagInfo:
raise NotImplementedError
@abc.abstractmethod
def cancel_dag(self, dag_id: str) -> DagInfo:
raise NotImplementedError
@abc.abstractmethod
def decref(self, tileable_keys: List[str]) -> None:
raise NotImplementedError
def get_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
return None


class MaxFrameSession(ToThreadMixin, IsolatedAsyncSession):
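    """
    Session implementation that submits tileable graphs to the MaxFrame
    service, uploads local data sources, polls DAG progress and fetches
    results from ODPS tables or volumes.
    """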
_odps_entry: Optional[ODPS]
_tileable_to_infos: Mapping[TileableType, ResultInfo]
@classmethod
async def init(
cls,
address: str,
session_id: Optional[str] = None,
backend: str = None,
odps_entry: Optional[ODPS] = None,
timeout: Optional[float] = None,
**kwargs,
) -> "AbstractSession":
session_obj = cls(
address, session_id, odps_entry=odps_entry, timeout=timeout, **kwargs
)
await session_obj._init(address)
return session_obj
def __init__(
self,
address: str,
session_id: str,
odps_entry: Optional[ODPS] = None,
timeout: Optional[float] = None,
**kwargs,
):
super().__init__(address, session_id)
self.timeout = timeout
self._odps_entry = odps_entry or ODPS.from_global() or ODPS.from_environments()
self._tileable_to_infos = weakref.WeakKeyDictionary()
        self._caller = self._create_caller(self._odps_entry, address, **kwargs)
self._last_settings = None
self._pull_interval = 1 if in_ipython_frontend() else 3
self._replace_internal_host = kwargs.get("replace_internal_host", True)
@classmethod
def _create_caller(
cls, odps_entry: ODPS, address: str, **kwargs
) -> MaxFrameServiceCaller:
raise NotImplementedError
async def _init(self, _address: str):
session_info = await self.ensure_async_call(self._caller.create_session)
self._last_settings = copy.deepcopy(self._caller.get_settings_to_upload())
self._session_id = session_info.session_id
await self._show_logview_address()
def _upload_and_get_table_read_tileable(
self, t: TileableType
) -> Optional[TileableType]:
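        """
        Upload a local pandas data source into a session-scoped temp table
        and return a read_odps_table tileable that can replace the source
        node in the graph.
        """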
        table_schema, table_meta = pandas_to_odps_schema(t, unknown_as_string=True)
        # drop any stale temp table left over from a previous upload with the
        # same key before recreating it
        table_name = build_temp_table_name(self.session_id, t.key)
        if self._odps_entry.exist_table(table_name):
            self._odps_entry.delete_table(table_name, hints=options.sql.settings)
        table_obj = self._odps_entry.create_table(
            table_name,
            table_schema,
            lifecycle=options.session.temp_table_lifecycle,
            hints=options.sql.settings,
            if_not_exists=True,
            table_properties=options.session.temp_table_properties
            or get_default_table_properties(),
        )
data = t.op.get_data()
batch_size = options.session.upload_batch_size
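        # convert local data to Arrow and upload it batch by batch so that a
        # single write request stays reasonably small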
if len(data):
table_client = ODPSTableIO(self._odps_entry)
with table_client.open_writer(table_obj.full_table_name) as writer:
for batch_start in range(0, len(data), batch_size):
if isinstance(data, pd.Index):
batch = data[batch_start : batch_start + batch_size]
else:
batch = data.iloc[batch_start : batch_start + batch_size]
arrow_batch, _ = pandas_to_arrow(batch)
writer.write(arrow_batch)
read_tileable = read_odps_table(
table_obj.full_table_name,
columns=table_meta.table_column_names,
index_col=table_meta.table_index_column_names,
output_type=table_meta.type,
)
if isinstance(read_tileable, DATAFRAME_TYPE):
if list(read_tileable.dtypes.index) != list(t.dtypes.index):
read_tileable.columns = list(t.dtypes.index)
elif isinstance(read_tileable, SERIES_TYPE):
if read_tileable.name != t.name:
read_tileable.name = t.name
else: # INDEX_TYPE
if list(read_tileable.names) != list(t.names):
read_tileable.rename(t.names, inplace=True)
read_tileable._key = t.key
read_tileable.params = t.params
return read_tileable.data
def _upload_and_get_vol_read_tileable(
self, t: TileableType
) -> Optional[TileableType]:
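        """
        Upload a non-tabular local data source (e.g. a tensor) into the
        session volume and return a fetch tileable replacing the source.
        """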
vol_name = build_session_volume_name(self.session_id)
writer = ODPSVolumeWriter(
self._odps_entry,
vol_name,
t.key,
replace_internal_host=self._replace_internal_host,
)
io_handler = get_object_io_handler(t)
io_handler().write_object(writer, t, t.op.data)
return build_fetch(t).data
def _upload_and_get_read_tileable(self, t: TileableType) -> Optional[TileableType]:
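        """
        Upload the data of a local source node and return its replacement.

        Returns None if the node is not a local data source. Pandas sources
        go into session temp tables while other sources (e.g. tensors) go
        into the session volume.
        """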
if (
not isinstance(t.op, (ArrayDataSource, PandasDataSourceOperator))
or t.op.get_data() is None
or t.inputs
):
return None
with sync_pyodps_options():
if isinstance(t.op, PandasDataSourceOperator):
return self._upload_and_get_table_read_tileable(t)
else:
return self._upload_and_get_vol_read_tileable(t)
@enter_mode(kernel=True, build=True)
def _scan_and_replace_local_sources(
self, graph: TileableGraph
) -> Dict[TileableType, TileableType]:
"""Replaces Pandas data sources with temp table sources in the graph"""
replacements = dict()
for t in graph:
replaced = self._upload_and_get_read_tileable(t)
if replaced is None:
continue
replacements[t] = replaced
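        # splice each replacement into the graph: drop the original source
        # node, add the replacement (plus its own inputs) and re-point all
        # former successors at it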
for src, replaced in replacements.items():
successors = list(graph.successors(src))
graph.remove_node(src)
graph.add_node(replaced)
for pred in replaced.inputs or ():
graph.add_node(pred)
graph.add_edge(pred, replaced)
for succ in successors:
graph.add_edge(replaced, succ)
succ.op._set_inputs([replacements.get(t, t) for t in succ.inputs])
graph.results = [replacements.get(t, t) for t in graph.results]
return replacements
@enter_mode(kernel=True, build=True)
def _get_input_infos(self, tileables: List[TileableType]) -> Dict[str, ResultInfo]:
"""Generate ResultInfo structs from generated temp tables"""
vol_name = build_session_volume_name(self.session_id)
infos = dict()
for t in tileables:
key = t.key
if isinstance(t.op, DataFrameReadODPSTable):
infos[key] = ODPSTableResultInfo(full_table_name=t.op.table_name)
else:
if isinstance(t.op, Fetch):
infos[key] = ODPSVolumeResultInfo(
volume_name=vol_name, volume_path=t.key
)
elif t.inputs and isinstance(t.inputs[0].op, DataFrameReadODPSTable):
t = t.inputs[0]
infos[key] = ODPSTableResultInfo(full_table_name=t.op.table_name)
return infos
def _get_diff_settings(self) -> Dict[str, Any]:
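        """
        Compute the settings changed since the last upload so that only the
        diff needs to be sent along with the next DAG submission.
        """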
new_settings = self._caller.get_settings_to_upload()
if not self._last_settings: # pragma: no cover
self._last_settings = copy.deepcopy(new_settings)
return new_settings
if self._last_settings.get("session.quota_name", None) != new_settings.get(
"session.quota_name", None
):
raise ValueError("Quota name cannot be changed after sessions are created")
update = dict()
for k in new_settings.keys():
old_item = self._last_settings.get(k)
new_item = new_settings.get(k)
try:
if old_item != new_item:
update[k] = new_item
except: # noqa: E722 # nosec # pylint: disable=bare-except
update[k] = new_item
self._last_settings = copy.deepcopy(new_settings)
return update
async def execute(self, *tileables, **kwargs) -> ExecutionInfo:
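        """
        Build a tileable graph for *tileables*, replace local data sources
        with uploaded ones, submit the graph as a DAG and return an
        ExecutionInfo whose task polls the DAG until it finishes.
        """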
tileables = [
tileable.data if isinstance(tileable, Entity) else tileable
for tileable in tileables
]
tileable_to_copied = dict()
tileable_graph, to_execute_tileables = gen_submit_tileable_graph(
self, tileables, tileable_to_copied
)
source_replacements = self._scan_and_replace_local_sources(tileable_graph)
        # uploaded data sources need to be managed by the refcounting
        # mechanism; as nodes in tileable_graph are copies, sessions must be
        # attached to the original (pre-copy) nodes
copied_to_tileable = {v: k for k, v in tileable_to_copied.items()}
for replaced_src in source_replacements.keys():
copied_to_tileable[replaced_src]._attach_session(self)
replaced_infos = self._get_input_infos(list(source_replacements.values()))
dag_info = await self.ensure_async_call(
self._caller.submit_dag,
tileable_graph,
replaced_infos,
self._get_diff_settings(),
)
await self._show_logview_address(dag_info.dag_id)
progress = Progress()
profiling = Profiling()
aio_task = asyncio.create_task(
self._run_in_background(dag_info, to_execute_tileables, progress)
)
return ExecutionInfo(
aio_task,
progress,
profiling,
asyncio.get_running_loop(),
to_execute_tileables,
)
async def _run_in_background(
self, dag_info: DagInfo, tileables: List, progress: Progress
):
start_time = time.time()
session_id = dag_info.session_id
dag_id = dag_info.dag_id
server_no_response_time = None
with enter_mode(build=True, kernel=True):
key_to_tileables = {t.key: t for t in tileables}
timeout_val = 0.1
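            # poll DAG status with exponential backoff, capped at the pull
            # interval (shorter in notebooks to keep progress responsive)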
try:
while True:
elapsed_time = time.time() - start_time
next_timeout_val = min(timeout_val * 2, self._pull_interval)
timeout_val = (
min(self.timeout - elapsed_time, next_timeout_val)
if self.timeout
else next_timeout_val
)
if timeout_val <= 0:
raise TimeoutError("Running DAG timed out")
try:
dag_info: DagInfo = await self.ensure_async_call(
self._caller.get_dag_info, dag_id
)
server_no_response_time = None
except (NoTaskServerResponseError, SessionAlreadyClosedError) as ex:
                        # if SessionAlreadyClosedError arrives right after a
                        # NoTaskServerResponseError, the task server may just
                        # have restarted and the error can be transient, so
                        # keep polling for a while; otherwise the error should
                        # be raised immediately.
if (
isinstance(ex, SessionAlreadyClosedError)
and not server_no_response_time
):
raise
server_no_response_time = server_no_response_time or time.time()
if (
time.time() - server_no_response_time
> options.client.task_restart_timeout
):
raise MaxFrameError(
"Failed to get valid response from service. "
f"Session {self._session_id}."
) from None
await asyncio.sleep(timeout_val)
continue
if dag_info is None:
raise SystemError(
f"Cannot find DAG with ID {dag_id} in session {session_id}"
)
progress.value = dag_info.progress
if dag_info.status != DagStatus.RUNNING:
break
await asyncio.sleep(timeout_val)
except asyncio.CancelledError:
dag_info = await self.ensure_async_call(self._caller.cancel_dag, dag_id)
if dag_info.status != DagStatus.CANCELLED: # pragma: no cover
raise
finally:
if dag_info.status == DagStatus.SUCCEEDED:
progress.value = 1.0
elif dag_info.status == DagStatus.FAILED:
dag_info.error_info.reraise()
if dag_info.status in (DagStatus.RUNNING, DagStatus.CANCELLED):
return
for key, result_info in dag_info.tileable_to_result_infos.items():
t = key_to_tileables[key]
fetcher = get_fetcher_cls(result_info.result_type)(self._odps_entry)
await fetcher.update_tileable_meta(t, result_info)
self._tileable_to_infos[t] = result_info
def _get_data_tileable_and_indexes(
self, tileable: TileableType
) -> Tuple[TileableType, List[Union[slice, Integral]]]:
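        """
        Walk up from *tileable* through slicing ops (iloc / tensor index)
        until an executed tileable with result info is found, returning it
        together with the indexes to apply at fetch time.
        """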
from maxframe.dataframe.indexing.iloc import (
DataFrameIlocGetItem,
SeriesIlocGetItem,
)
from maxframe.tensor.indexing import TensorIndex
slice_op_types = TensorIndex, DataFrameIlocGetItem, SeriesIlocGetItem
if isinstance(tileable, Entity):
tileable = tileable.data
indexes = None
while tileable not in self._tileable_to_infos:
            # if the tileable's op is a slicing op, record its indexes and
            # check its input instead
if isinstance(tileable.op, slice_op_types):
indexes = tileable.op.indexes
tileable = tileable.inputs[0]
if not all(isinstance(index, (slice, Integral)) for index in indexes):
raise ValueError("Only support fetch data slices")
else:
raise ValueError(f"Cannot fetch unexecuted tileable: {tileable!r}")
return tileable, indexes
async def fetch(self, *tileables, **kwargs) -> list:
results = []
tileables = [
tileable.data if isinstance(tileable, Entity) else tileable
for tileable in tileables
]
with enter_mode(build=True):
for tileable in tileables:
data_tileable, indexes = self._get_data_tileable_and_indexes(tileable)
info = self._tileable_to_infos[data_tileable]
fetcher = get_fetcher_cls(info.result_type)(self._odps_entry)
results.append(await fetcher.fetch(data_tileable, info, indexes))
return results
async def decref(self, *tileable_keys):
return await self.ensure_async_call(self._caller.decref, list(tileable_keys))
async def destroy(self):
await self.ensure_async_call(self._caller.delete_session)
await super().destroy()
async def _get_ref_counts(self) -> Dict[str, int]:
pass
async def fetch_tileable_op_logs(
self,
tileable_op_key: str,
offsets: Union[Dict[str, List[int]], str, int],
sizes: Union[Dict[str, List[int]], str, int],
) -> Dict:
pass
async def get_total_n_cpu(self):
pass
async def get_cluster_versions(self) -> List[str]:
raise NotImplementedError
async def get_web_endpoint(self) -> Optional[str]:
raise NotImplementedError
async def create_remote_object(
self, session_id: str, name: str, object_cls, *args, **kwargs
):
raise NotImplementedError
async def get_remote_object(self, session_id: str, name: str):
raise NotImplementedError
async def destroy_remote_object(self, session_id: str, name: str):
raise NotImplementedError
async def create_mutable_tensor(
self,
shape: tuple,
dtype: Union[np.dtype, str],
name: str = None,
default_value: Union[int, float] = 0,
chunk_size: Union[int, Tuple] = None,
):
raise NotImplementedError
async def get_mutable_tensor(self, name: str):
raise NotImplementedError
async def get_logview_address(self, hours=None) -> Optional[str]:
return await self.get_dag_logview_address(None, hours)
async def get_dag_logview_address(self, dag_id=None, hours=None) -> Optional[str]:
return await self.ensure_async_call(
self._caller.get_logview_address, dag_id, hours
)
async def _show_logview_address(self, dag_id=None, hours=None):
identity = f"Session ID: {self._session_id}"
if dag_id:
identity += f", DAG ID: {dag_id}"
logview_addr = await self.get_dag_logview_address(dag_id, hours)
if logview_addr:
logger.info("%s, Logview: %s", identity, logview_addr)
else:
logger.info("%s, Logview address does not exist", identity)


class MaxFrameRestCaller(MaxFrameServiceCaller):
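    """MaxFrameServiceCaller implementation over the RESTful FrameDriver API."""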
_client: FrameDriverClient
_session_id: Optional[str]
def __init__(self, odps_entry: ODPS, client: FrameDriverClient):
self._odps_entry = odps_entry
self._client = client
self._session_id = None
async def create_session(self) -> SessionInfo:
info = await self._client.create_session(options.to_dict(remote_only=True))
self._session_id = info.session_id
return info
async def delete_session(self) -> None:
await self._client.delete_session(self._session_id)
async def submit_dag(
self,
dag: TileableGraph,
managed_input_infos: Dict[str, ResultInfo] = None,
new_settings: Dict[str, Any] = None,
) -> DagInfo:
return await self._client.submit_dag(
self._session_id, dag, managed_input_infos, new_settings=new_settings
)
async def get_dag_info(self, dag_id: str) -> DagInfo:
return await self._client.get_dag_info(self._session_id, dag_id)
async def cancel_dag(self, dag_id: str) -> DagInfo:
return await self._client.cancel_dag(self._session_id, dag_id)
async def decref(self, tileable_keys: List[str]) -> None:
return await self._client.decref(self._session_id, tileable_keys)


class MaxFrameRestSession(MaxFrameSession):
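    """
    Session connecting to the MaxFrame service via a RESTful endpoint, with
    the session-specific URL scheme translated into plain http / https.
    """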
schemes = [RESTFUL_SESSION_INSECURE_SCHEME, RESTFUL_SESSION_SECURE_SCHEME]
def __init__(
self,
address: str,
session_id: str,
odps_entry: Optional[ODPS] = None,
timeout: Optional[float] = None,
new: bool = True,
**kwargs,
):
parsed_endpoint = urlparse(address)
scheme = (
"http"
if parsed_endpoint.scheme == RESTFUL_SESSION_INSECURE_SCHEME
else "https"
)
real_endpoint = address.replace(f"{parsed_endpoint.scheme}://", f"{scheme}://")
super().__init__(
real_endpoint, session_id, odps_entry=odps_entry, timeout=timeout, **kwargs
)
@classmethod
def _create_caller(cls, odps_entry: ODPS, address: str, **kwargs):
return MaxFrameRestCaller(odps_entry, FrameDriverClient(address))


def register_session_schemes(overwrite: bool = False):
MaxFrameRestSession.register_schemes(overwrite=overwrite)
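

# A minimal usage sketch, assuming the scheme constants defined in consts.py
# (the scheme placeholder below is illustrative, not the actual constant value):
#
#   from maxframe.session import new_session
#
#   register_session_schemes()
#   session = new_session("<restful-session-scheme>://<framedriver-endpoint>")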