core/maxframe_client/fetcher.py (220 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from numbers import Integral
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
import pyarrow as pa
from odps import ODPS
from odps.models import ExternalVolume

from maxframe.core import OBJECT_TYPE
from maxframe.dataframe.core import DATAFRAME_TYPE
from maxframe.io.objects import get_object_io_handler
from maxframe.io.odpsio import (
    ODPSTableIO,
    ODPSVolumeReader,
    TunnelTableIO,
    arrow_to_pandas,
    build_dataframe_table_meta,
    odps_schema_to_pandas_dtypes,
)
from maxframe.protocol import (
    DataFrameTableMeta,
    ODPSTableResultInfo,
    ODPSVolumeResultInfo,
    ResultInfo,
    ResultType,
)
from maxframe.tensor.core import TENSOR_TYPE
from maxframe.typing_ import PandasObjectTypes, TileableType
from maxframe.utils import ToThreadMixin, sync_pyodps_options

# Registry mapping a result type to the fetcher class that handles it.
_result_fetchers: Dict[ResultType, Type["ResultFetcher"]] = dict()


def register_fetcher(fetcher_cls: Type["ResultFetcher"]):
    """Register ``fetcher_cls`` under its ``result_type`` and return it.

    Returning the class unchanged lets this function be used as a class
    decorator.
    """
    _result_fetchers[fetcher_cls.result_type] = fetcher_cls
    return fetcher_cls


def get_fetcher_cls(result_type: ResultType) -> Type["ResultFetcher"]:
    """Return the fetcher class registered for ``result_type``.

    Raises:
        KeyError: if no fetcher was registered for ``result_type``.
    """
    return _result_fetchers[result_type]


class ResultFetcher(ABC):
    """Abstract base of result fetchers.

    Subclasses set ``result_type`` (the key used by ``register_fetcher``)
    and implement ``update_tileable_meta`` and ``fetch``.
    """

    result_type = None

    def __init__(self, odps_entry: ODPS):
        self._odps_entry = odps_entry

    @abstractmethod
    async def update_tileable_meta(
        self,
        tileable: TileableType,
        info: ResultInfo,
    ) -> None:
        """Update metadata of ``tileable`` using the result ``info``."""
        raise NotImplementedError

    @abstractmethod
    async def fetch(
        self,
        tileable: TileableType,
        info: ResultInfo,
        indexes: List[Union[None, Integral, slice]],
    ) -> Any:
        """Fetch the data of ``tileable`` selected by ``indexes``."""
        raise NotImplementedError


@register_fetcher
class NullFetcher(ResultFetcher):
    """Fetcher for ``ResultType.NULL``: both operations do nothing."""

    result_type = ResultType.NULL

    async def update_tileable_meta(
        self,
        tileable: TileableType,
        info: ResultInfo,
    ) -> None:
        return

    async def fetch(
        self,
        tileable: TileableType,
        info: ODPSTableResultInfo,
        indexes: List[Union[None, Integral, slice]],
    ) -> None:
        return


@register_fetcher
class ODPSTableFetcher(ToThreadMixin, ResultFetcher):
    """Fetcher for results stored in ODPS tables (``ResultType.ODPS_TABLE``)."""

    result_type = ResultType.ODPS_TABLE

    def _get_table_comment(self, table_name: str) -> Optional[str]:
        """Return the ``comment`` attribute of the table, or None if absent."""
        table = self._odps_entry.get_table(table_name)
        return getattr(table, "comment", None)

    async def update_tileable_meta(
        self,
        tileable: TileableType,
        info: ODPSTableResultInfo,
    ) -> None:
        """Fill in missing dtypes and unknown shape entries of ``tileable``.

        Dtypes are taken from ``info.table_meta`` when it carries pandas
        column dtypes and otherwise derived from the table schema. When any
        entry of ``tileable.shape`` is NA, the first dimension is replaced by
        the summed record counts of the download sessions created for the
        table's partitions.
        """
        if (
            isinstance(tileable, DATAFRAME_TYPE)
            and tileable.dtypes is None
            and info.table_meta is not None
        ):
            if info.table_meta.pd_column_dtypes is not None:
                tileable.refresh_from_table_meta(info.table_meta)
            else:
                # need to get meta directly from table
                table = self._odps_entry.get_table(info.full_table_name)
                pd_dtypes = odps_schema_to_pandas_dtypes(table.table_schema).drop(
                    info.table_meta.table_index_column_names
                )
                tileable.refresh_from_dtypes(pd_dtypes)

        if tileable.shape and any(pd.isna(x) for x in tileable.shape):
            # a single ``None`` spec is used when no partition specs are given
            part_specs = [None] if not info.partition_specs else info.partition_specs

            with sync_pyodps_options():
                table = self._odps_entry.get_table(info.full_table_name)
                if isinstance(tileable, DATAFRAME_TYPE) and tileable.dtypes is None:
                    dtypes = odps_schema_to_pandas_dtypes(table.table_schema)
                    tileable.refresh_from_dtypes(dtypes)

                part_sessions = TunnelTableIO.create_download_sessions(
                    self._odps_entry, info.full_table_name, part_specs
                )
                total_records = sum(session.count for session in part_sessions.values())

            new_shape_list = list(tileable.shape)
            new_shape_list[0] = total_records
            tileable.params = {"shape": tuple(new_shape_list)}

    @staticmethod
    def _align_selection_with_shape(
        row_sel: slice, shape: Tuple[Optional[int], ...]
    ) -> dict:
        """Translate a row slice into keyword arguments for the table reader.

        Args:
            row_sel: row slice requested by the caller.
            shape: shape of the tileable; ``shape[0]`` may be NA when unknown.

        Returns:
            An empty dict when both ``start`` and ``stop`` of the slice are
            falsy; otherwise a dict with ``start``, ``stop`` and
            ``reverse_range``. When the row count is known, ``start`` (for
            negative steps) or ``stop`` (for non-negative steps) is clamped
            to it.
        """
        size = shape[0]
        # NOTE(review): a start/stop of 0 is treated the same as None here, so
        # e.g. ``slice(0, 0)`` selects the full range -- confirm this is
        # intended against the reader's semantics.
        if not row_sel.start and not row_sel.stop:
            return {}
        is_reversed = row_sel.step is not None and row_sel.step < 0
        read_kw = {
            "start": row_sel.start,
            "stop": row_sel.stop,
            "reverse_range": is_reversed,
        }
        if pd.isna(size):
            # row count unknown, bounds cannot be clamped
            return read_kw

        if is_reversed and row_sel.start is not None:
            read_kw["start"] = min(size - 1, row_sel.start)
        if not is_reversed and row_sel.stop is not None:
            read_kw["stop"] = min(size, row_sel.stop)
        return read_kw

    def _read_single_source(
        self,
        table_meta: DataFrameTableMeta,
        info: ODPSTableResultInfo,
        indexes: List[Union[None, Integral, slice]],
        shape: Tuple[Optional[int], ...],
    ):
        """Read the selected rows and columns of the result table.

        Args:
            table_meta: table meta providing data and index column names.
            info: result info providing the table name and partition specs.
            indexes: ``[row_sel]`` or ``[row_sel, col_sel]``; each selector
                is None, an integer or a slice. The list is not modified.
            shape: shape of the tileable, used to clamp row bounds.

        Returns:
            The object returned by ``reader.read_all()``, sliced by the row
            step when the row selector carries a non-zero step.

        Raises:
            NotImplementedError: if a selector has an unsupported type.
        """
        table_io = ODPSTableIO(self._odps_entry)
        read_kw = {}
        row_step = None
        if indexes:
            # Pad a local copy instead of using ``+=`` so that the caller's
            # list is left untouched (and tuples are accepted as well).
            padded_indexes = list(indexes)
            if len(padded_indexes) < 2:
                padded_indexes.append(None)
            row_sel, col_sel = padded_indexes

            if isinstance(row_sel, slice):
                row_step = row_sel.step
                read_kw = self._align_selection_with_shape(row_sel, shape)
            elif isinstance(row_sel, Integral):
                # accept any Integral (e.g. numpy integers) as annotated and
                # hand plain ints to the reader
                row_pos = int(row_sel)
                read_kw["start"] = row_pos
                read_kw["stop"] = row_pos + 1
                row_step = None
            elif row_sel is not None:  # pragma: no cover
                raise NotImplementedError(f"Does not support row index {row_sel!r}")

            if isinstance(col_sel, (Integral, slice)):
                data_cols = table_meta.table_column_names[col_sel]
                if isinstance(col_sel, Integral):
                    # a single position yields one name; wrap it into a list
                    data_cols = [data_cols]
                # index columns are always read along with the data columns
                read_kw["columns"] = table_meta.table_index_column_names + data_cols
            elif col_sel is not None:  # pragma: no cover
                raise NotImplementedError(f"Does not support column index {col_sel!r}")

        with table_io.open_reader(
            info.full_table_name, info.partition_specs, **read_kw
        ) as reader:
            result = reader.read_all()
            reader_count = result.num_rows

        if not row_step:
            return result

        # start / stop were already passed to the reader, thus only the step
        # is applied to the rows that were read
        if row_step >= 0:
            slice_start = 0
            slice_stop = reader_count
        else:
            slice_start = reader_count - 1
            slice_stop = None
        return result[slice_start:slice_stop:row_step]

    async def fetch(
        self,
        tileable: TileableType,
        info: ODPSTableResultInfo,
        indexes: List[Union[None, Integral, slice]],
    ) -> PandasObjectTypes:
        """Read the selected data via ``self.to_thread`` and convert it
        into a pandas object."""
        table_meta = build_dataframe_table_meta(tileable)
        arrow_table: pa.Table = await self.to_thread(
            self._read_single_source, table_meta, info, indexes, tileable.shape
        )
        return arrow_to_pandas(arrow_table, table_meta)


@register_fetcher
class ODPSVolumeFetcher(ToThreadMixin, ResultFetcher):
    """Fetcher for results stored in ODPS volumes (``ResultType.ODPS_VOLUME``)."""

    result_type = ResultType.ODPS_VOLUME

    async def update_tileable_meta(
        self,
        tileable: TileableType,
        info: ODPSVolumeResultInfo,
    ) -> None:
        return

    async def _fetch_object(
        self,
        tileable: TileableType,
        info: ODPSVolumeResultInfo,
        indexes: List[Union[Integral, slice]],
    ) -> Any:
        """Read an object from the volume described by ``info``.

        Raises:
            NotImplementedError: if the volume is not an ``ExternalVolume``.
        """

        def volume_fetch_func():
            reader = ODPSVolumeReader(
                self._odps_entry,
                info.volume_name,
                info.volume_path,
                replace_internal_host=True,
            )
            io_handler = get_object_io_handler(tileable)()
            return io_handler.read_object(reader, tileable, indexes)

        volume = await self.to_thread(self._odps_entry.get_volume, info.volume_name)
        if isinstance(volume, ExternalVolume):
            return await self.to_thread(volume_fetch_func)
        else:
            raise NotImplementedError(f"Volume type {type(volume)} not supported")

    async def fetch(
        self,
        tileable: TileableType,
        info: ODPSVolumeResultInfo,
        indexes: List[Union[Integral, slice]],
    ) -> Any:
        """Fetch an object or tensor tileable from the volume.

        Raises:
            NotImplementedError: for any other kind of tileable.
        """
        if isinstance(tileable, (OBJECT_TYPE, TENSOR_TYPE)):
            return await self._fetch_object(tileable, info, indexes)
        raise NotImplementedError(f"Fetching {type(tileable)} not implemented")