tzrec/datasets/odps_dataset.py (504 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import os import random import threading import time from collections import OrderedDict from typing import Any, Dict, Iterator, List, Optional, Tuple import pyarrow as pa import urllib3 from alibabacloud_credentials.client import Client as CredClient from odps import ODPS from odps.accounts import ( AliyunAccount, BaseAccount, CredentialProviderAccount, ) from odps.apis.storage_api import ( ArrowReader, Compression, ReadRowsRequest, SessionRequest, SessionStatus, SplitOptions, Status, StorageApiArrowClient, TableBatchScanRequest, TableBatchWriteRequest, WriteRowsRequest, ) from odps.errors import ODPSError from torch import distributed as dist from tzrec.constant import Mode from tzrec.datasets.dataset import BaseDataset, BaseReader, BaseWriter from tzrec.datasets.utils import calc_slice_position, remove_nullable from tzrec.features.feature import BaseFeature from tzrec.protos import data_pb2 from tzrec.utils import dist_util from tzrec.utils.logging_util import logger ODPS_READ_SESSION_EXPIRED_TIME = 18 * 3600 TYPE_TABLE_TO_PA = { "BIGINT": pa.int64(), "DOUBLE": pa.float64(), "FLOAT": pa.float32(), "STRING": pa.string(), "INT": pa.int32(), "ARRAY<BIGINT>": pa.list_(pa.int64()), "ARRAY<DOUBLE>": pa.list_(pa.float64()), "ARRAY<FLOAT>": pa.list_(pa.float32()), "ARRAY<STRING>": pa.list_(pa.string()), "ARRAY<INT>": pa.list_(pa.int32()), "ARRAY<ARRAY<BIGINT>>": pa.list_(pa.list_(pa.int64())), "ARRAY<ARRAY<DOUBLE>>": pa.list_(pa.list_(pa.float64())), "ARRAY<ARRAY<FLOAT>>": pa.list_(pa.list_(pa.float32())), "ARRAY<ARRAY<STRING>>": pa.list_(pa.list_(pa.string())), "ARRAY<ARRAY<INT>>": pa.list_(pa.list_(pa.int32())), "MAP<STRING,BIGINT>": pa.map_(pa.string(), pa.int64()), "MAP<STRING,DOUBLE>": pa.map_(pa.string(), pa.float64()), "MAP<STRING,FLOAT>": pa.map_(pa.string(), pa.float32()), "MAP<STRING,STRING>": pa.map_(pa.string(), pa.string()), "MAP<STRING,INT>": pa.map_(pa.string(), pa.int32()), "MAP<BIGINT,BIGINT>": pa.map_(pa.int64(), pa.int64()), "MAP<BIGINT,DOUBLE>": pa.map_(pa.int64(), pa.float64()), "MAP<BIGINT,FLOAT>": pa.map_(pa.int64(), pa.float32()), "MAP<BIGINT,STRING>": pa.map_(pa.int64(), pa.string()), "MAP<BIGINT,INT>": pa.map_(pa.int64(), pa.int32()), "MAP<INT,BIGINT>": pa.map_(pa.int32(), pa.int64()), "MAP<INT,DOUBLE>": pa.map_(pa.int32(), pa.float64()), "MAP<INT,FLOAT>": pa.map_(pa.int32(), pa.float32()), "MAP<INT,STRING>": pa.map_(pa.int32(), pa.string()), "MAP<INT,INT>": pa.map_(pa.int32(), pa.int32()), } def _get_compression_type(compression_name: str) -> Compression: type_names = [x.name for x in Compression] if compression_name in type_names: return Compression[compression_name] else: raise ValueError( f"Unknown compression type: {compression_name}, available {type_names}" ) def _type_pa_to_table(pa_type: pa.DataType) -> str: """PyArrow type to MaxCompute Table type.""" mc_type = None pa_type = remove_nullable(pa_type) for k, v in TYPE_TABLE_TO_PA.items(): # list<element: int64> and list<item: int64> is equal if v == pa_type: mc_type = k break if mc_type: return mc_type else: raise RuntimeError(f"{pa_type} is not supported now.") def _parse_odps_config_file(odps_config_path: str) -> Tuple[str, str, str]: """Parse odps config file.""" if os.path.exists(odps_config_path): odps_config = {} with open(odps_config_path, "r") as f: for line in f.readlines(): values = line.split("=", 1) if len(values) == 2: odps_config[values[0]] = values[1].strip() else: raise ValueError("No such file: %s" % odps_config_path) try: access_id = odps_config["access_id"] access_key = odps_config["access_key"] end_point = odps_config["end_point"] except KeyError as err: raise IOError( "%s key does not exist in the %s file." % (str(err), odps_config_path) ) from err return access_id, access_key, end_point def _create_odps_account() -> Tuple[BaseAccount, str]: account = None sts_token = None if "ODPS_CONFIG_FILE_PATH" in os.environ: account_id, account_key, odps_endpoint = _parse_odps_config_file( os.environ["ODPS_CONFIG_FILE_PATH"] ) account = AliyunAccount(account_id, account_key) elif ( "ALIBABA_CLOUD_CREDENTIALS_URI" in os.environ or "ALIBABA_CLOUD_SECURITY_TOKEN" in os.environ or "ALIBABA_CLOUD_CREDENTIALS_FILE" in os.environ or "ALIBABA_CLOUD_ECS_METADATA" in os.environ ): credentials_client = CredClient() # prevent too much request to credential server after forked credential = credentials_client.get_credential() account_id = credential.access_key_id account_key = credential.access_key_secret sts_token = credential.security_token account = CredentialProviderAccount(credentials_client) try: odps_endpoint = os.environ["ODPS_ENDPOINT"] except KeyError as err: raise RuntimeError( "ODPS_ENDPOINT does not exist in environment variables." ) from err else: account_id, account_key, odps_endpoint = _parse_odps_config_file( os.path.join(os.getenv("HOME", "/home/admin"), ".odps_config.ini") ) account = AliyunAccount(account_id, account_key) # prevent graph-learn parse odps config hang os.environ["ACCESS_ID"] = account_id os.environ["ACCESS_KEY"] = account_key os.environ["END_POINT"] = odps_endpoint if sts_token: os.environ["STS_TOKEN"] = sts_token return account, odps_endpoint def _parse_table_path(odps_table_path: str) -> Tuple[str, str, Optional[List[str]]]: """Method that parse odps table path.""" str_list = odps_table_path.split("/") if len(str_list) < 5 or str_list[3] != "tables": raise ValueError( f"'{odps_table_path}' is invalid, please refer:" "'odps://${your_projectname}/tables/${table_name}/${pt_1}/${pt_2}&${pt_1}/${pt_2}'" ) table_partition = "/".join(str_list[5:]) if not table_partition: table_partitions = None else: table_partitions = table_partition.split("&") return str_list[2], str_list[4], table_partitions def _read_rows_arrow_with_retry( client: StorageApiArrowClient, read_req: ReadRowsRequest, ) -> ArrowReader: max_retry_count = 3 retry_cnt = 0 while True: try: reader = client.read_rows_arrow(read_req) except ODPSError as e: if retry_cnt >= max_retry_count: raise e retry_cnt += 1 time.sleep(random.choice([5, 9, 12])) continue break return reader def _reader_iter( client: StorageApiArrowClient, sess_reqs: List[SessionRequest], worker_id: int, num_workers: int, batch_size: int, drop_redundant_bs_eq_one: bool, compression: Compression, ) -> Iterator[pa.RecordBatch]: num_sess = len(sess_reqs) remain_row_count = 0 for i, sess_req in enumerate(sess_reqs): while True: scan_resp = client.get_read_session(sess_req) if scan_resp.session_status == SessionStatus.INIT: time.sleep(1) continue break start, end, remain_row_count = calc_slice_position( # pyre-ignore [6] scan_resp.record_count, worker_id, num_workers, batch_size, drop_redundant_bs_eq_one if i == num_sess - 1 else False, remain_row_count, ) if start == end: return offset = 0 retry_cnt = 0 read_req = ReadRowsRequest( session_id=sess_req.session_id, row_index=start, row_count=end - start, max_batch_rows=min(batch_size, 20000), compression=compression, ) reader = _read_rows_arrow_with_retry(client, read_req) max_retry_count = 5 while True: try: read_data = reader.read() retry_cnt = 0 # pyre-ignore [66] except (urllib3.exceptions.HTTPError, pa.lib.ArrowInvalid) as e: if retry_cnt >= max_retry_count: raise e retry_cnt += 1 read_req = ReadRowsRequest( session_id=sess_req.session_id, row_index=start + offset, row_count=end - start - offset, max_batch_rows=min(batch_size, 20000), compression=compression, ) reader = _read_rows_arrow_with_retry(client, read_req) continue if read_data is None: break else: retry_cnt = 0 offset += len(read_data) yield read_data def _refresh_sessions_daemon(sess_id_to_cli: Dict[str, StorageApiArrowClient]) -> None: start_time = time.time() while True: if time.time() - start_time > ODPS_READ_SESSION_EXPIRED_TIME: for session_id, client in sess_id_to_cli.items(): logger.info(f"refresh session: {session_id}") client.get_read_session(SessionRequest(session_id, refresh=True)) start_time = time.time() time.sleep(5) class OdpsDataset(BaseDataset): """Dataset for reading data in Odps(Maxcompute). Args: data_config (DataConfig): an instance of DataConfig. features (list): list of features. input_path (str): data input path. """ def __init__( self, data_config: data_pb2.DataConfig, features: List[BaseFeature], input_path: str, **kwargs: Any, ) -> None: if int(os.environ.get("WORLD_SIZE", 1)) > 1: assert dist.is_initialized(), ( "You should initialize distribute group first." ) super().__init__(data_config, features, input_path, **kwargs) # pyre-ignore [29] self._reader = OdpsReader( input_path, self._batch_size, list(self._selected_input_names) if self._selected_input_names else None, self._data_config.drop_remainder, is_orderby_partition=self._data_config.is_orderby_partition, quota_name=self._data_config.odps_data_quota_name, drop_redundant_bs_eq_one=self._mode != Mode.PREDICT, compression=self._data_config.odps_data_compression, ) self._init_input_fields() class OdpsReader(BaseReader): """Odps(Maxcompute) reader class. Args: input_path (str): data input path. batch_size (int): batch size. selected_cols (list): selection column names. drop_remainder (bool): drop last batch less than batch_size. shuffle (bool): shuffle data or not. shuffle_buffer_size (int): buffer size for shuffle. is_orderby_partition (bool): read data order by table partitions or not. quota_name (str): storage api quota name. drop_redundant_bs_eq_one (bool): drop last redundant batch with batch_size equal one to prevent train_eval hung. compression (str): storage api data compression name. """ def __init__( self, input_path: str, batch_size: int, selected_cols: Optional[List[str]] = None, drop_remainder: bool = False, shuffle: bool = False, shuffle_buffer_size: int = 32, is_orderby_partition: bool = False, quota_name: str = "pay-as-you-go", drop_redundant_bs_eq_one: bool = False, compression: str = "LZ4_FRAME", **kwargs: Any, ) -> None: super().__init__( input_path, batch_size, selected_cols, drop_remainder, shuffle, shuffle_buffer_size, ) self._is_orderby_partition = is_orderby_partition self._quota_name = quota_name self._compression = _get_compression_type(compression) os.environ["STORAGE_API_QUOTA_NAME"] = quota_name self._drop_redundant_bs_eq_one = drop_redundant_bs_eq_one self._account, self._odps_endpoint = _create_odps_account() self._proj_to_o = {} self._table_to_cli = {} self._input_to_sess = {} self._init_client() self.schema = [] self._ordered_cols = [] _, table_name, _ = _parse_table_path(self._input_path.split(",")[0]) table = self._table_to_cli[table_name].table for column in table.schema.simple_columns: if not self._selected_cols or column.name in self._selected_cols: column_type = str(column.type).upper() if column_type not in TYPE_TABLE_TO_PA: raise ValueError( f"column [{column.name}] with dtype {column.type} " "is not supported now." ) self.schema.append( pa.field(column.name, TYPE_TABLE_TO_PA[str(column.type).upper()]) ) self._ordered_cols.append(column.name) self._init_session() def _init_client(self) -> None: """Init storage api client.""" for input_path in self._input_path.split(","): project, table_name, _ = _parse_table_path(input_path) if project not in self._proj_to_o: self._proj_to_o[project] = ODPS( account=self._account, project=project, endpoint=self._odps_endpoint, ) if table_name not in self._table_to_cli: o = self._proj_to_o[project] self._table_to_cli[table_name] = StorageApiArrowClient( odps=o, table=o.get_table(table_name), quota_name=self._quota_name ) def _init_session(self) -> None: """Init table scan session.""" sess_id_to_cli = {} for input_path in self._input_path.split(","): session_ids = [] _, table_name, partitions = _parse_table_path(input_path) client = self._table_to_cli[table_name] if self._is_orderby_partition and partitions is not None: splited_partitions = [[x] for x in partitions] else: splited_partitions = [partitions] for partitions in splited_partitions: if int(os.environ.get("RANK", 0)) == 0: scan_req = TableBatchScanRequest( split_options=SplitOptions(split_mode="RowOffset"), required_data_columns=self._ordered_cols, required_partitions=partitions, ) scan_resp = client.create_read_session(scan_req) session_ids.append(scan_resp.session_id) sess_id_to_cli[scan_resp.session_id] = client else: session_ids.append(None) if dist.is_initialized(): dist.broadcast_object_list(session_ids) self._input_to_sess[input_path] = [ SessionRequest(session_id=x) for x in session_ids ] # refresh session if int(os.environ.get("RANK", 0)) == 0: t = threading.Thread( target=_refresh_sessions_daemon, args=(sess_id_to_cli,), daemon=True, ) t.start() def _iter_one_table( self, input_path: str, worker_id: int = 0, num_workers: int = 1 ) -> Iterator[Dict[str, pa.Array]]: _, table_name, _ = _parse_table_path(input_path) client = self._table_to_cli[table_name] sess_reqs = self._input_to_sess[input_path] iterator = _reader_iter( client, sess_reqs, worker_id, num_workers, self._batch_size, self._drop_redundant_bs_eq_one, self._compression, ) yield from self._arrow_reader_iter(iterator) def to_batches( self, worker_id: int = 0, num_workers: int = 1 ) -> Iterator[Dict[str, pa.Array]]: """Get batch iterator.""" for input_path in self._input_path.split(","): yield from self._iter_one_table(input_path, worker_id, num_workers) class OdpsWriter(BaseWriter): """Odps(Maxcompute) writer class. Args: output_path (str): data output path. quota_name (str): storage api quota name. """ def __init__( self, output_path: str, quota_name: str = "pay-as-you-go", **kwargs: Any ) -> None: if int(os.environ.get("WORLD_SIZE", 1)) > 1: assert dist.is_initialized(), ( "You should initialize distribute group first." ) super().__init__(output_path) self._account, self._odps_endpoint = _create_odps_account() self._quota_name = quota_name os.environ["STORAGE_API_QUOTA_NAME"] = quota_name self._project, self._table_name, partitions = _parse_table_path(output_path) if partitions is None: self._partition_spec = None else: self._partition_spec = partitions[0] self._o = ODPS( account=self._account, project=self._project, endpoint=self._odps_endpoint, ) self._client = None self._sess_req = None self._writer = None def _create_table(self, output_dict: OrderedDict[str, pa.Array]) -> None: """Create output table.""" if not self._o.exist_table(self._table_name): schemas = [] for k, v in output_dict.items(): schemas.append(f"{k} {_type_pa_to_table(v.type)}") schema = ",".join(schemas) if self._partition_spec: pt_schemas = [] for pt_spec in self._partition_spec.split("/"): pt_name = pt_spec.split("=")[0] pt_schemas.append(f"{pt_name} STRING") schema = (schema, ",".join(pt_schemas)) self._o.create_table( self._table_name, schema, hints={"odps.sql.type.system.odps2": "true"} ) def _create_partition(self) -> None: """Create output partition.""" if self._partition_spec: t = self._o.get_table(self._table_name) partition_spec = self._partition_spec.replace("/", ",") if not t.exist_partition(partition_spec): t.create_partition(partition_spec, if_not_exists=True) def _init_writer(self) -> None: """Initialize table writer.""" self._client = StorageApiArrowClient( odps=self._o, table=self._o.get_table(self._table_name), quota_name=self._quota_name, ) session_id = None if int(os.environ.get("RANK", 0)) == 0: write_req = TableBatchWriteRequest( partition_spec=self._partition_spec, overwrite=True ) write_resp = self._client.create_write_session(write_req) session_id = write_resp.session_id if dist.is_initialized(): session_id = dist_util.broadcast_string(session_id) self._sess_req = SessionRequest(session_id=session_id) while True: sess_resp = self._client.get_write_session(self._sess_req) if sess_resp.session_status == SessionStatus.INIT: time.sleep(1) continue break row_req = WriteRowsRequest( session_id=sess_resp.session_id, block_number=int(os.environ.get("RANK", 0)) ) self._writer = self._client.write_rows_arrow(row_req) def _wait_init_table(self) -> None: """Wait table and partition ready.""" while True: if not self._o.exist_table(self._table_name): time.sleep(1) continue t = self._o.get_table(self._table_name) if self._partition_spec: partition_spec = self._partition_spec.replace("/", ",") if not t.exist_partition(partition_spec): time.sleep(1) continue break def write(self, output_dict: OrderedDict[str, pa.Array]) -> None: """Write a batch of data.""" if not self._lazy_inited: if int(os.environ.get("RANK", 0)) == 0: self._create_table(output_dict) self._create_partition() else: self._wait_init_table() self._init_writer() self._lazy_inited = True record_batch = pa.RecordBatch.from_arrays( list(output_dict.values()), list(output_dict.keys()), ) self._writer.write(record_batch) def close(self) -> None: """Close and commit data.""" if self._writer is not None: commit_msg, _ = self._writer.finish() if dist.is_initialized(): commit_msgs = dist_util.gather_strings(commit_msg) else: commit_msgs = [commit_msg] if int(os.environ.get("RANK", 0)) == 0: resp = self._client.commit_write_session(self._sess_req, commit_msgs) while resp.status == Status.WAIT: time.sleep(1) resp = self._client.get_write_session(self._sess_req) if resp.session_status != SessionStatus.COMMITTED: raise RuntimeError( f"Fail to commit write session: {self._sess_req.session_id}" ) super().close()