tzrec/datasets/odps_dataset_v1.py (124 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, Dict, Iterator, List, Optional

import common_io
import pyarrow as pa
from odps import options

from tzrec.constant import Mode
from tzrec.datasets.dataset import BaseDataset, BaseReader
from tzrec.features.feature import BaseFeature
from tzrec.protos import data_pb2

# pyre-ignore [16]
options.read_timeout = int(os.getenv("TUNNEL_READ_TIMEOUT", "120"))

# Mapping from the reader schema's "typestr" to the pyarrow type used both for
# the arrays produced by `_pa_read` and for the input fields declared by
# `OdpsDatasetV1`. Keeping a single table guarantees the two always agree.
# "double" maps to float64 so 8-byte doubles are not silently truncated.
TYPE2PA = {
    "bigint": pa.int64(),
    "double": pa.float64(),
    "boolean": pa.int64(),
    "string": pa.string(),
    "datetime": pa.int64(),
}


def _pa_read(
    reader: common_io.table.TableReader,
    num_records: int = 1,
    allow_smaller_final_batch: bool = False,
) -> Dict[str, pa.Array]:
    """Read up to `num_records` rows from the table as column arrays.

    Args:
        reader (TableReader): an opened common_io table reader.
        num_records (int): number of rows requested for this batch.
        allow_smaller_final_batch (bool): if True, a final batch with fewer
            than `num_records` rows is returned; if False, such a remainder
            is treated as end of data.

    Returns:
        a dict mapping column name to a pa.Array holding the rows read.

    Raises:
        common_io.exception.OutOfRangeException: when no rows are left, or
            when fewer than `num_records` rows are left and smaller final
            batches are not allowed.
        KeyError: when a column's typestr has no entry in TYPE2PA.
    """
    reader._check_status()
    left_count = reader._end_pos - reader._read_pos
    if left_count <= 0 or (not allow_smaller_final_batch and left_count < num_records):
        raise common_io.exception.OutOfRangeException("No more data to read.")
    num_records = min(num_records, left_count)

    schema = reader.get_schema()
    # One python list per column; rows are transposed into it as they arrive.
    columns = [[] for _ in schema]
    for _ in range(num_records):
        record = reader._do_read_with_retry()
        # The private read call does not advance the cursor itself.
        reader._read_pos += 1
        for i, v in enumerate(record.values):
            columns[i].append(v)

    result_dict = {}
    for col_values, s in zip(columns, schema):
        result_dict[s["colname"]] = pa.array(col_values, type=TYPE2PA[s["typestr"]])
    return result_dict


class OdpsDatasetV1(BaseDataset):
    """Dataset for reading data in Odps(Maxcompute).

    Args:
        data_config (DataConfig): an instance of DataConfig.
        features (list): list of features.
        input_path (str): data input path.
    """

    def __init__(
        self,
        data_config: data_pb2.DataConfig,
        features: List[BaseFeature],
        input_path: str,
        **kwargs: Any,
    ) -> None:
        super().__init__(data_config, features, input_path, **kwargs)
        # pyre-ignore [29]
        self._reader = OdpsReaderV1(
            input_path,
            self._batch_size,
            list(self._selected_input_names),
            self._data_config.drop_remainder,
            # Shuffling is only enabled for training.
            shuffle=self._data_config.shuffle and self._mode == Mode.TRAIN,
        )
        self._init_input_fields()

    def _init_input_fields(self) -> None:
        """Init input fields info."""
        # Use the module-level TYPE2PA so declared field types always match
        # the types of the arrays produced by `_pa_read`.
        self._input_fields = [
            pa.field(name=s["colname"], type=TYPE2PA[s["typestr"]])
            for s in self._reader.schema
        ]
        # prevent graph-learn parse odps config hang
        os.environ["END_POINT"] = os.environ["ODPS_ENDPOINT"]


class OdpsReaderV1(BaseReader):
    """Odps(Maxcompute) reader class.

    Args:
        input_path (str): data input path, multiple tables are separated
            by comma.
        batch_size (int): batch size.
        selected_cols (list): selection column names.
        drop_remainder (bool): drop last batch.
        shuffle (bool): shuffle data or not.
        shuffle_buffer_size (int): buffer size for shuffle.
    """

    def __init__(
        self,
        input_path: str,
        batch_size: int,
        selected_cols: Optional[List[str]] = None,
        drop_remainder: bool = False,
        shuffle: bool = False,
        shuffle_buffer_size: int = 32,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            input_path,
            batch_size,
            selected_cols,
            drop_remainder,
            shuffle,
            shuffle_buffer_size,
        )
        self.schema = []
        self._ordered_cols = []
        # An empty/None selection means "all columns".
        wanted = set(selected_cols) if selected_cols else None

        # The schema is taken from the first table only.
        reader = common_io.table.TableReader(
            self._input_path.split(",")[0],
            selected_cols=",".join(selected_cols or []),
        )
        try:
            for field in reader.get_schema():
                if wanted is None or field["colname"] in wanted:
                    self.schema.append(field)
                    # Remember the table's own column order so later reads
                    # request columns in the order the schema was recorded.
                    self._ordered_cols.append(field["colname"])
        finally:
            # Always release the reader, even if schema inspection fails.
            reader.close()

    def _iter_one_table(
        self, input_path: str, worker_id: int = 0, num_workers: int = 1
    ) -> Iterator[Dict[str, pa.Array]]:
        """Yield batches from a single table slice.

        Args:
            input_path (str): path of one table.
            worker_id (int): slice index read by this worker.
            num_workers (int): total number of slices.

        Yields:
            a dict mapping column name to a pa.Array for each batch.
        """
        reader = common_io.table.TableReader(
            input_path,
            slice_id=worker_id,
            slice_count=num_workers,
            selected_cols=",".join(self._ordered_cols or []),
        )
        try:
            while True:
                try:
                    data = _pa_read(
                        reader,
                        num_records=self._batch_size,
                        # A smaller final batch is allowed exactly when the
                        # remainder must NOT be dropped.
                        allow_smaller_final_batch=not self._drop_remainder,
                    )
                except common_io.exception.OutOfRangeException:
                    break
                yield data
        finally:
            # Runs on exhaustion, on errors, and when the consumer closes the
            # generator early, so the reader is never leaked.
            reader.close()

    def to_batches(
        self, worker_id: int = 0, num_workers: int = 1
    ) -> Iterator[Dict[str, pa.Array]]:
        """Get batch iterator."""
        for input_path in self._input_path.split(","):
            yield from self._iter_one_table(input_path, worker_id, num_workers)