tzrec/datasets/dataset.py (431 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import numpy as np
import pyarrow as pa
from torch import distributed as dist
from torch.utils.data import IterableDataset, get_worker_info

from tzrec.constant import Mode
from tzrec.datasets.data_parser import DataParser
from tzrec.datasets.sampler import BaseSampler, TDMSampler
from tzrec.datasets.utils import (
    C_NEG_SAMPLE_MASK,
    C_SAMPLE_MASK,
    Batch,
    RecordBatchTensor,
    process_hstu_neg_sample,
    process_hstu_seq_data,
    remove_nullable,
)
from tzrec.features.feature import BaseFeature
from tzrec.protos import data_pb2
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger

# Name -> class registries, filled through the metaclasses created below.
_DATASET_CLASS_MAP = {}
_READER_CLASS_MAP = {}
_WRITER_CLASS_MAP = {}
_dataset_meta_cls = get_register_class_meta(_DATASET_CLASS_MAP)
_reader_meta_cls = get_register_class_meta(_READER_CLASS_MAP)
_writer_meta_cls = get_register_class_meta(_WRITER_CLASS_MAP)

# Arrow column types accepted by BaseDataset._init_input_fields.
AVAILABLE_PA_TYPES = {
    pa.int64(),
    pa.float64(),
    pa.float32(),
    pa.string(),
    pa.int32(),
    pa.list_(pa.int64()),
    pa.list_(pa.float64()),
    pa.list_(pa.float32()),
    pa.list_(pa.string()),
    pa.list_(pa.int32()),
    pa.list_(pa.list_(pa.int64())),
    pa.list_(pa.list_(pa.float64())),
    pa.list_(pa.list_(pa.float32())),
    pa.list_(pa.list_(pa.string())),
    pa.list_(pa.list_(pa.int32())),
    pa.map_(pa.string(), pa.int64()),
    pa.map_(pa.string(), pa.float64()),
    pa.map_(pa.string(), pa.float32()),
    pa.map_(pa.string(), pa.string()),
    pa.map_(pa.string(), pa.int32()),
    pa.map_(pa.int64(), pa.int64()),
    pa.map_(pa.int64(), pa.float64()),
    pa.map_(pa.int64(), pa.float32()),
    pa.map_(pa.int64(), pa.string()),
    pa.map_(pa.int64(), pa.int32()),
    pa.map_(pa.int32(), pa.int64()),
    pa.map_(pa.int32(), pa.float64()),
    pa.map_(pa.int32(), pa.float32()),
    pa.map_(pa.int32(), pa.string()),
    pa.map_(pa.int32(), pa.int32()),
}


def _expand_tdm_sample(
    input_data: Dict[str, pa.Array],
    pos_sampled: Dict[str, pa.Array],
    neg_sampled: Dict[str, pa.Array],
    data_config: data_pb2.DataConfig,
) -> Dict[str, pa.Array]:
    """Expand input data with sampled data for tdm.

    Combine the sampled positive and negative samples with the item
    features, then expand the user features based on the original user-item
    relationships, and supplement the corresponding labels according to the
    positive and negative samples. Note that in the sampling results, the
    sampled outcomes for each item are contiguous.

    for example:
        user_fea:[1, 2], item_fea:[0.1, 0.2], labels:[1,1],
        pos_sample:[0.11, 0.12, 0.21, 0.22],
        neg_sample:[-0.11, -0.12, -0.21, -0.22]

        concat item_fea:
            [0.1, 0.2, 0.11, 0.12, 0.21, 0.22, -0.11, -0.12, -0.21, -0.22]
        duplicate user_fea and keep origin user-item relationship:
            [1, 2, 1, 1, 2, 2, 1, 1, 2, 2]
        expand label:
            [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]

    Args:
        input_data (dict): raw input columns; updated in place.
        pos_sampled (dict): sampled positive item columns; its keys define
            which columns are treated as item features.
        neg_sampled (dict): sampled negative item columns, indexed with the
            same keys as pos_sampled.
        data_config (DataConfig): only label_fields is read.

    Returns:
        the same input_data dict with every column expanded.
    """
    item_fea_names = pos_sampled.keys()
    all_fea_names = input_data.keys()
    label_fields = set(data_config.label_fields)
    # Everything that is neither an item feature nor a label is handled as
    # a user-side column and replicated per sampled item.
    user_fea_names = all_fea_names - item_fea_names - label_fields

    for item_fea_name in item_fea_names:
        input_data[item_fea_name] = pa.concat_arrays(
            [
                input_data[item_fea_name],
                pos_sampled[item_fea_name],
                neg_sampled[item_fea_name],
            ]
        )

    # In the sampling results, the sampled outcomes for each item are contiguous.
    batch_size = len(input_data[list(label_fields)[0]])
    num_pos_sampled = len(pos_sampled[list(item_fea_names)[0]])
    num_neg_sampled = len(neg_sampled[list(item_fea_names)[0]])
    user_pos_index = np.repeat(np.arange(batch_size), num_pos_sampled // batch_size)
    user_neg_index = np.repeat(np.arange(batch_size), num_neg_sampled // batch_size)

    for user_fea_name in user_fea_names:
        user_fea = input_data[user_fea_name]
        pos_expand_user_fea = user_fea.take(user_pos_index)
        neg_expand_user_fea = user_fea.take(user_neg_index)
        input_data[user_fea_name] = pa.concat_arrays(
            [
                input_data[user_fea_name],
                pos_expand_user_fea,
                neg_expand_user_fea,
            ]
        )

    for label_field in label_fields:
        # Original labels are kept; sampled positives get 1, negatives 0.
        input_data[label_field] = pa.concat_arrays(
            [
                input_data[label_field].cast(pa.int64()),
                pa.array([1] * num_pos_sampled, type=pa.int64()),
                pa.array([0] * num_neg_sampled, type=pa.int64()),
            ]
        )

    return input_data


class BaseDataset(IterableDataset, metaclass=_dataset_meta_cls):
    """Dataset base class.

    Args:
        data_config (DataConfig): an instance of DataConfig.
        features (list): list of features.
        input_path (str): data input path.
        reserved_columns (list): reserved columns in predict mode.
        mode (Mode): train or eval or predict.
        debug_level (int): dataset debug level, when mode=predict and
            debug_level > 0, will dump fg encoded data to debug_str
    """

    def __init__(
        self,
        data_config: data_pb2.DataConfig,
        features: List[BaseFeature],
        input_path: str,
        reserved_columns: Optional[List[str]] = None,
        mode: Mode = Mode.EVAL,
        debug_level: int = 0,
    ) -> None:
        super(BaseDataset, self).__init__()
        self._data_config = data_config
        self._features = features
        self._input_path = input_path
        self._reserved_columns = reserved_columns or []
        self._mode = mode
        self._debug_level = debug_level
        self._enable_hstu = data_config.enable_hstu

        # Labels and sample weights are not handed to the parser in predict mode.
        self._data_parser = DataParser(
            features=features,
            labels=list(data_config.label_fields)
            if self._mode != Mode.PREDICT
            else None,
            sample_weights=list(data_config.sample_weight_fields)
            if self._mode != Mode.PREDICT
            else None,
            is_training=self._mode == Mode.TRAIN,
            fg_threads=data_config.fg_threads,
            force_base_data_group=data_config.force_base_data_group,
        )

        self._input_fields = None
        self._selected_input_names = set()
        self._selected_input_names |= self._data_parser.feature_input_names
        if self._mode == Mode.PREDICT:
            self._selected_input_names |= set(self._reserved_columns)
        else:
            self._selected_input_names |= set(data_config.label_fields)
            self._selected_input_names |= set(data_config.sample_weight_fields)
        if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
            sampler_type = self._data_config.WhichOneof("sampler")
            sampler_config = getattr(self._data_config, sampler_type)
            # The id columns named by the sampler config must be selected too.
            if hasattr(sampler_config, "item_id_field") and sampler_config.HasField(
                "item_id_field"
            ):
                self._selected_input_names.add(sampler_config.item_id_field)
            if hasattr(sampler_config, "user_id_field") and sampler_config.HasField(
                "user_id_field"
            ):
                self._selected_input_names.add(sampler_config.user_id_field)
        # if set selected_input_names to None,
        # all columns will be reserved.
        if (
            len(self._reserved_columns) > 0
            and self._reserved_columns[0] == "ALL_COLUMNS"
        ):
            self._selected_input_names = None

        self._fg_mode = data_config.fg_mode
        self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep

        if mode != Mode.TRAIN and data_config.HasField("eval_batch_size"):
            self._batch_size = data_config.eval_batch_size
        else:
            self._batch_size = data_config.batch_size

        self._sampler = None
        self._sampler_inited = False

        self._reader = None

    def launch_sampler_cluster(
        self,
        num_client_per_rank: int = 1,
        client_id_bias: int = 0,
        cluster: Optional[Dict[str, Union[int, str]]] = None,
    ) -> None:
        """Launch sampler cluster and server.

        Does nothing when no sampler is configured or in predict mode.

        Args:
            num_client_per_rank (int): forwarded to the sampler's init_cluster.
            client_id_bias (int): forwarded to the sampler's init_cluster.
            cluster (dict, optional): forwarded to the sampler's init_cluster;
                the sampler server is launched only when this is None.
        """
        if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
            sampler_type = self._data_config.WhichOneof("sampler")
            sampler_config = getattr(self._data_config, sampler_type)
            # pyre-ignore [16]
            self._sampler = BaseSampler.create_class(sampler_config.__class__.__name__)(
                sampler_config,
                self.input_fields,
                self._batch_size,
                is_training=self._mode == Mode.TRAIN,
                multival_sep=self._fg_encoded_multival_sep
                if self._fg_mode == data_pb2.FgMode.FG_NONE
                else chr(29),
            )
            self._sampler.init_cluster(num_client_per_rank, client_id_bias, cluster)
            if cluster is None:
                self._sampler.launch_server()

    def get_sampler_cluster(self) -> Optional[Dict[str, Union[int, str]]]:
        """Get sampler cluster, or None when no sampler has been created."""
        if self._sampler:
            return self._sampler._cluster
        return None

    def _init_input_fields(self) -> None:
        """Init input fields info.

        Raises:
            ValueError: if a reader schema column has a type that is not in
                AVAILABLE_PA_TYPES after remove_nullable is applied.
        """
        self._input_fields = []
        for field in self._reader.schema:
            field_type = remove_nullable(field.type)
            # Explicit equality scan, same comparison as before.
            if any(field_type == pa_type for pa_type in AVAILABLE_PA_TYPES):
                self._input_fields.append(field)
            else:
                raise ValueError(
                    f"column [{field.name}] with dtype {field.type} "
                    "is not supported now."
                )

    @property
    def input_fields(self) -> List[pa.Field]:
        """Input fields info, overridden by subclass for auto infer the info."""
        if not self._input_fields:
            self._input_fields = list(self._data_config.input_fields)
        return self._input_fields

    def get_worker_info(self) -> Tuple[int, int]:
        """Get multiprocessing dataloader worker id and worker number.

        Returns:
            a tuple (rank * num_workers + worker_id, num_workers * world_size);
            rank 0 / world size 1 are used when torch.distributed is not
            initialized, worker 0 of 1 when there is no dataloader worker info.
        """
        worker_info = get_worker_info()
        if worker_info is None:
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers

        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1

        return rank * num_workers + worker_id, num_workers * world_size

    def __iter__(self) -> Iterator[Batch]:
        """Yield one Batch per raw batch produced by the reader."""
        # The sampler is initialized once, on first iteration.
        if self._sampler is not None and not self._sampler_inited:
            self._sampler.init()
            self._sampler_inited = True

        worker_id, num_workers = self.get_worker_info()
        for input_data in self._reader.to_batches(worker_id, num_workers):
            yield self._build_batch(input_data)

    def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
        """Process input data and build batch.

        Args:
            input_data (dict): raw input data; columns may be added or
                replaced in place.

        Returns:
            an instance of Batch.
        """
        use_sample_mask = self._mode == Mode.TRAIN and (
            self._data_config.negative_sample_mask_prob > 0
            or self._data_config.sample_mask_prob > 0
        )
        if use_sample_mask:
            input_data[C_SAMPLE_MASK] = pa.array(
                np.random.random(len(list(input_data.values())[0]))
                < self._data_config.sample_mask_prob
            )

        if self._sampler is not None:
            if isinstance(self._sampler, TDMSampler):
                pos_sampled, neg_sampled = self._sampler.get(input_data)
                input_data = _expand_tdm_sample(
                    input_data, pos_sampled, neg_sampled, self._data_config
                )
            elif self._enable_hstu:
                seq_attr = self._sampler._item_id_field

                (
                    input_data_k_split,
                    input_data_k_split_slice,
                    pre_seq_filter_reshaped_joined,
                ) = process_hstu_seq_data(
                    input_data=input_data,
                    seq_attr=seq_attr,
                    seq_str_delim=self._sampler.item_id_delim,
                )
                if self._mode == Mode.TRAIN:
                    # Training using all possible target items
                    input_data[seq_attr] = input_data_k_split_slice
                elif self._mode == Mode.EVAL:
                    # Evaluation using the last item for previous sequence
                    input_data[seq_attr] = input_data_k_split.values.take(
                        pa.array(input_data_k_split.offsets.to_numpy()[1:] - 1)
                    )
                sampled = self._sampler.get(input_data)
                # To keep consistent with other process, use two functions
                for k, v in sampled.items():
                    if k in input_data:
                        combined = process_hstu_neg_sample(
                            input_data,
                            v,
                            self._sampler._num_sample,
                            self._sampler.item_id_delim,
                            seq_attr,
                        )
                        # Combine here to make embeddings of both user sequence
                        # and target item are the same
                        input_data[k] = pa.concat_arrays(
                            [pre_seq_filter_reshaped_joined, combined]
                        )
                    else:
                        input_data[k] = v
            else:
                sampled = self._sampler.get(input_data)
                for k, v in sampled.items():
                    if k in input_data:
                        input_data[k] = pa.concat_arrays([input_data[k], v])
                    else:
                        input_data[k] = v

                # NOTE(review): nesting restored from a whitespace-collapsed
                # source; the negative mask is assumed to belong to this
                # generic sampler branch, where `sampled` is defined - confirm.
                if use_sample_mask:
                    input_data[C_NEG_SAMPLE_MASK] = pa.concat_arrays(
                        [
                            input_data[C_SAMPLE_MASK],
                            pa.array(
                                np.random.random(len(list(sampled.values())[0]))
                                < self._data_config.negative_sample_mask_prob
                            ),
                        ]
                    )

        # TODO(hongsheng.jhs): add additional field like hard_negative

        output_data = self._data_parser.parse(input_data)
        if self._mode == Mode.PREDICT:
            batch = self._data_parser.to_batch(output_data, force_no_tile=True)
            reserved_data = {}
            if (
                len(self._reserved_columns) > 0
                and self._reserved_columns[0] == "ALL_COLUMNS"
            ):
                reserved_data = input_data
            else:
                for k in self._reserved_columns:
                    reserved_data[k] = input_data[k]
            if self._debug_level > 0:
                reserved_data["__features__"] = self._data_parser.dump_parsed_inputs(
                    output_data
                )
            if len(reserved_data) > 0:
                batch.reserves = RecordBatchTensor(pa.record_batch(reserved_data))
        else:
            batch = self._data_parser.to_batch(output_data)
        return batch

    @property
    def sampled_batch_size(self) -> int:
        """Batch size with sampler."""
        if self._sampler:
            return self._batch_size + self._sampler.estimated_sample_num
        else:
            return self._batch_size


class BaseReader(metaclass=_reader_meta_cls):
    """Reader base class.

    Args:
        input_path (str): data input path.
        batch_size (int): batch size.
        selected_cols (list): selection column names.
        drop_remainder (bool): drop last batch.
        shuffle (bool): shuffle data or not.
        shuffle_buffer_size (int): buffer size for shuffle.
    """

    def __init__(
        self,
        input_path: str,
        batch_size: int,
        selected_cols: Optional[List[str]] = None,
        drop_remainder: bool = False,
        shuffle: bool = False,
        shuffle_buffer_size: int = 32,
        **kwargs: Any,
    ) -> None:
        self._input_path = input_path
        self._batch_size = batch_size
        self._selected_cols = selected_cols
        self._drop_remainder = drop_remainder
        self._shuffle = shuffle
        self._shuffle_buffer_size = shuffle_buffer_size

    def to_batches(
        self, worker_id: int = 0, num_workers: int = 1
    ) -> Iterator[Dict[str, pa.Array]]:
        """Get batch iterator."""
        raise NotImplementedError

    def _arrow_reader_iter(
        self, reader: Iterator[pa.RecordBatch]
    ) -> Iterator[Dict[str, pa.Array]]:
        """Re-chunk record batches into dicts of at most batch_size rows.

        Incoming record batches are accumulated until at least batch_size rows
        are buffered, then emitted in batch_size slices. When the reader is
        exhausted the buffered remainder is emitted unless drop_remainder is
        set. With shuffle enabled, batches pass through a buffer of
        shuffle_buffer_size entries and are popped at random.

        Args:
            reader (Iterator): source of pa.RecordBatch.

        Yields:
            dict of column name to pa.Array.
        """
        shuffle_buffer = []
        buff_data = None
        while True:
            data = None
            if buff_data is None or len(buff_data) < self._batch_size:
                # Not enough rows buffered yet: pull one more record batch.
                try:
                    read_data = next(reader)
                    if buff_data is None:
                        buff_data = pa.Table.from_batches([read_data])
                    else:
                        buff_data = pa.concat_tables(
                            [buff_data, pa.Table.from_batches([read_data])]
                        )
                except StopIteration:
                    data = None if self._drop_remainder else buff_data
                    buff_data = None
            elif len(buff_data) == self._batch_size:
                data = buff_data
                buff_data = None
            else:
                data = buff_data.slice(0, self._batch_size)
                buff_data = buff_data.slice(self._batch_size)

            if data is not None:
                data_dict = {}
                for name, column in zip(data.column_names, data.columns):
                    if isinstance(column, pa.ChunkedArray):
                        column = column.combine_chunks()
                    data_dict[name] = column

                if self._shuffle:
                    shuffle_buffer.append(data_dict)
                    if len(shuffle_buffer) < self._shuffle_buffer_size:
                        # Keep filling the buffer before emitting anything.
                        continue
                    else:
                        idx = random.randrange(len(shuffle_buffer))
                        data_dict = shuffle_buffer.pop(idx)

                yield data_dict

            # Reader exhausted and nothing left buffered.
            if data is None and buff_data is None:
                break

        # Flush whatever is still held in the shuffle buffer.
        if len(shuffle_buffer) > 0:
            random.shuffle(shuffle_buffer)
            for data_dict in shuffle_buffer:
                yield data_dict


class BaseWriter(metaclass=_writer_meta_cls):
    """Writer base class.

    Args:
        output_path (str): data output path.
    """

    def __init__(self, output_path: str, **kwargs: Any) -> None:
        self._lazy_inited = False
        self._output_path = output_path

    def write(self, output_dict: OrderedDict[str, pa.Array]) -> None:
        """Write a batch of data."""
        raise NotImplementedError

    def close(self) -> None:
        """Close and commit data."""
        self._lazy_inited = False

    def __del__(self) -> None:
        # getattr guards finalization of an instance whose constructor failed
        # before BaseWriter.__init__ set the flag.
        if getattr(self, "_lazy_inited", False):
            # pyre-ignore [16]
            logger.warning(f"You should close {self.__class__.__name__} explicitly.")


def create_reader(
    input_path: str,
    batch_size: int,
    selected_cols: Optional[List[str]] = None,
    reader_type: Optional[str] = None,
    **kwargs: Any,
) -> BaseReader:
    """Create data reader.

    Args:
        input_path (str): data input path.
        batch_size (int): batch size.
        selected_cols (list): selection column names.
        reader_type (str, optional): specify the input path reader type,
            if we cannot infer from input_path.
        **kwargs: additional params.

    Returns:
        reader: a data reader.
    """
    if input_path.startswith("odps://"):
        reader_cls_name = "OdpsReader"
    elif input_path.endswith(".csv"):
        reader_cls_name = "CsvReader"
    elif input_path.endswith(".parquet"):
        reader_cls_name = "ParquetReader"
    else:
        assert reader_type is not None, "You should set reader_type."
        reader_cls_name = reader_type

    # pyre-ignore [16]
    reader = BaseReader.create_class(reader_cls_name)(
        input_path=input_path,
        batch_size=batch_size,
        selected_cols=selected_cols,
        **kwargs,
    )
    return reader


def create_writer(
    output_path: str, writer_type: Optional[str] = None, **kwargs: Any
) -> BaseWriter:
    """Create data writer.

    Args:
        output_path (str): data output path.
        writer_type (str, optional): specify the input path writer type,
            if we cannot infer from input_path.
        **kwargs: additional params.

    Returns:
        writer: a data writer.
    """
    if output_path.startswith("odps://"):
        writer_cls_name = "OdpsWriter"
    else:
        assert writer_type is not None, "You should set writer_type."
        writer_cls_name = writer_type

    # pyre-ignore [16]
    writer = BaseWriter.create_class(writer_cls_name)(output_path=output_path, **kwargs)
    return writer