tzrec/datasets/dataset.py (431 lines of code) (raw):
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from collections import OrderedDict
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import numpy as np
import pyarrow as pa
from torch import distributed as dist
from torch.utils.data import IterableDataset, get_worker_info
from tzrec.constant import Mode
from tzrec.datasets.data_parser import DataParser
from tzrec.datasets.sampler import BaseSampler, TDMSampler
from tzrec.datasets.utils import (
C_NEG_SAMPLE_MASK,
C_SAMPLE_MASK,
Batch,
RecordBatchTensor,
process_hstu_neg_sample,
process_hstu_seq_data,
remove_nullable,
)
from tzrec.features.feature import BaseFeature
from tzrec.protos import data_pb2
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger
# Name-keyed class registries. ``get_register_class_meta`` (project helper)
# builds a metaclass over each map; BaseDataset / BaseReader / BaseWriter below
# use these metaclasses and expose ``create_class(name)`` for lookup.
# Presumably subclasses are recorded in the map at class-creation time.
_DATASET_CLASS_MAP = {}
_dataset_meta_cls = get_register_class_meta(_DATASET_CLASS_MAP)

_READER_CLASS_MAP = {}
_reader_meta_cls = get_register_class_meta(_READER_CLASS_MAP)

_WRITER_CLASS_MAP = {}
_writer_meta_cls = get_register_class_meta(_WRITER_CLASS_MAP)
# Arrow column types accepted as dataset input (after nullability is removed).
# Five scalar types are allowed bare, as a list, and as a list of lists; maps
# are allowed with a string / int64 / int32 key and any of the scalar values.
AVAILABLE_PA_TYPES = {
    allowed_type
    for scalar_type in (
        pa.int64(),
        pa.float64(),
        pa.float32(),
        pa.string(),
        pa.int32(),
    )
    for allowed_type in (
        scalar_type,
        pa.list_(scalar_type),
        pa.list_(pa.list_(scalar_type)),
    )
}
AVAILABLE_PA_TYPES |= {
    pa.map_(key_type, value_type)
    for key_type in (pa.string(), pa.int64(), pa.int32())
    for value_type in (
        pa.int64(),
        pa.float64(),
        pa.float32(),
        pa.string(),
        pa.int32(),
    )
}
def _expand_tdm_sample(
    input_data: Dict[str, pa.Array],
    pos_sampled: Dict[str, pa.Array],
    neg_sampled: Dict[str, pa.Array],
    data_config: data_pb2.DataConfig,
) -> Dict[str, pa.Array]:
    """Expand input data with sampled data for tdm.

    Combine the sampled positive and negative samples with the item
    features, then expand the user features based on the original user-item
    relationships, and supplement the corresponding labels according to the
    positive and negative samples. Note that in the sampling results, the
    sampled outcomes for each item are contiguous.

    for example:
    user_fea:[1, 2], item_fea:[0.1, 0.2], labels:[1,1],
    pos_sample:[0.11, 0.12, 0.21, 0.22], neg_sample:[-0.11, -0.12, -0.21, -0.22]
    concat item_fea:[0.1, 0.2, 0.11, 0.12, 0.21, 0.22, -0.11, -0.12, -0.21, -0.22]
    duplicate user_fea and keep origin user-item
    relationship: [1, 2, 1, 1, 2, 2, 1, 1, 2, 2]
    expand label: [1, 1, 1, 1, 1, 1, 0, 0, 0, 0]

    Args:
        input_data (dict): original batch columns. It is modified in place
            and also returned.
        pos_sampled (dict): positive sampled columns. Its keys define which
            columns are item features.
        neg_sampled (dict): negative sampled columns with the same keys as
            ``pos_sampled``.
        data_config (DataConfig): provides ``label_fields``.

    Returns:
        the expanded ``input_data``. Every column is then
        batch_size + num_pos_sampled + num_neg_sampled long.

    Raises:
        ValueError: if the positive or negative sample count is not a
            multiple of the batch size. Each original row must own the same
            number of contiguous samples; otherwise user columns would end up
            shorter than item and label columns.
    """
    item_fea_names = pos_sampled.keys()
    label_fields = set(data_config.label_fields)
    # Everything that is neither an item feature nor a label is duplicated
    # per sample (this also covers helper columns such as sample masks).
    user_fea_names = input_data.keys() - item_fea_names - label_fields

    # Read all sizes before any column is expanded.
    batch_size = len(input_data[list(label_fields)[0]])
    first_item_fea = list(item_fea_names)[0]
    num_pos_sampled = len(pos_sampled[first_item_fea])
    num_neg_sampled = len(neg_sampled[first_item_fea])
    pos_per_row, pos_rem = divmod(num_pos_sampled, batch_size)
    neg_per_row, neg_rem = divmod(num_neg_sampled, batch_size)
    if pos_rem or neg_rem:
        raise ValueError(
            f"tdm sampler returned {num_pos_sampled} positive and "
            f"{num_neg_sampled} negative samples, which are not multiples of "
            f"the batch size {batch_size}."
        )

    for item_fea_name in item_fea_names:
        input_data[item_fea_name] = pa.concat_arrays(
            [
                input_data[item_fea_name],
                pos_sampled[item_fea_name],
                neg_sampled[item_fea_name],
            ]
        )

    # In the sampling results, the sampled outcomes for each item are
    # contiguous, so row i owns a run of pos_per_row / neg_per_row samples.
    row_ids = np.arange(batch_size)
    user_pos_index = np.repeat(row_ids, pos_per_row)
    user_neg_index = np.repeat(row_ids, neg_per_row)
    for user_fea_name in user_fea_names:
        user_fea = input_data[user_fea_name]
        input_data[user_fea_name] = pa.concat_arrays(
            [
                user_fea,
                user_fea.take(user_pos_index),
                user_fea.take(user_neg_index),
            ]
        )

    # Positive samples are labelled 1 and negative samples 0; arrays are
    # immutable, so one instance is shared by every label field.
    pos_labels = pa.array(np.ones(num_pos_sampled, dtype=np.int64))
    neg_labels = pa.array(np.zeros(num_neg_sampled, dtype=np.int64))
    for label_field in label_fields:
        input_data[label_field] = pa.concat_arrays(
            [
                input_data[label_field].cast(pa.int64()),
                pos_labels,
                neg_labels,
            ]
        )
    return input_data
class BaseDataset(IterableDataset, metaclass=_dataset_meta_cls):
    """Dataset base class.

    Subclasses are expected to assign ``self._reader``. ``__iter__`` asks that
    reader for batches using a global worker id / worker count (dataloader
    workers x distributed ranks) and converts each one into a ``Batch``.

    Args:
        data_config (DataConfig): an instance of DataConfig.
        features (list): list of features.
        input_path (str): data input path.
        reserved_columns (list): reserved columns in predict mode. If the
            first entry is "ALL_COLUMNS", every input column is read and
            reserved.
        mode (Mode): train or eval or predict.
        debug_level (int): dataset debug level. When mode=predict and
            debug_level > 0, the parsed (fg encoded) inputs are dumped into
            the ``__features__`` reserved column.
    """

    def __init__(
        self,
        data_config: data_pb2.DataConfig,
        features: List[BaseFeature],
        input_path: str,
        reserved_columns: Optional[List[str]] = None,
        mode: Mode = Mode.EVAL,
        debug_level: int = 0,
    ) -> None:
        super(BaseDataset, self).__init__()
        self._data_config = data_config
        self._features = features
        self._input_path = input_path
        self._reserved_columns = reserved_columns or []
        self._mode = mode
        self._debug_level = debug_level
        self._enable_hstu = data_config.enable_hstu

        # Labels and sample weights are not parsed in predict mode.
        self._data_parser = DataParser(
            features=features,
            labels=list(data_config.label_fields)
            if self._mode != Mode.PREDICT
            else None,
            sample_weights=list(data_config.sample_weight_fields)
            if self._mode != Mode.PREDICT
            else None,
            is_training=self._mode == Mode.TRAIN,
            fg_threads=data_config.fg_threads,
            force_base_data_group=data_config.force_base_data_group,
        )

        # Names of the input columns that must be read.
        self._input_fields = None
        self._selected_input_names = set()
        self._selected_input_names |= self._data_parser.feature_input_names
        if self._mode == Mode.PREDICT:
            self._selected_input_names |= set(self._reserved_columns)
        else:
            self._selected_input_names |= set(data_config.label_fields)
            self._selected_input_names |= set(data_config.sample_weight_fields)
        if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
            # Id columns named in the sampler config are read as well
            # (presumably the sampler keys its lookups on them).
            sampler_type = self._data_config.WhichOneof("sampler")
            sampler_config = getattr(self._data_config, sampler_type)
            if hasattr(sampler_config, "item_id_field") and sampler_config.HasField(
                "item_id_field"
            ):
                self._selected_input_names.add(sampler_config.item_id_field)
            if hasattr(sampler_config, "user_id_field") and sampler_config.HasField(
                "user_id_field"
            ):
                self._selected_input_names.add(sampler_config.user_id_field)
        # if set selected_input_names to None,
        # all columns will be reserved.
        if self._reserves_all_columns():
            self._selected_input_names = None

        self._fg_mode = data_config.fg_mode
        self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep

        if mode != Mode.TRAIN and data_config.HasField("eval_batch_size"):
            self._batch_size = data_config.eval_batch_size
        else:
            self._batch_size = data_config.batch_size

        self._sampler = None
        self._sampler_inited = False
        self._reader = None

    def _reserves_all_columns(self) -> bool:
        """Whether reserved_columns requests every input column (ALL_COLUMNS)."""
        return (
            len(self._reserved_columns) > 0
            and self._reserved_columns[0] == "ALL_COLUMNS"
        )

    def launch_sampler_cluster(
        self,
        num_client_per_rank: int = 1,
        client_id_bias: int = 0,
        cluster: Optional[Dict[str, Union[int, str]]] = None,
    ) -> None:
        """Launch sampler cluster and server.

        Does nothing when no sampler is configured or in predict mode.

        Args:
            num_client_per_rank (int): forwarded to the sampler's init_cluster.
            client_id_bias (int): forwarded to the sampler's init_cluster.
            cluster (dict, optional): existing cluster info. When None, the
                sampler's launch_server() is also called.
        """
        if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
            sampler_type = self._data_config.WhichOneof("sampler")
            sampler_config = getattr(self._data_config, sampler_type)
            # pyre-ignore [16]
            self._sampler = BaseSampler.create_class(sampler_config.__class__.__name__)(
                sampler_config,
                self.input_fields,
                self._batch_size,
                is_training=self._mode == Mode.TRAIN,
                # Without fg the configured separator is used; with fg the
                # multi-value separator is chr(29).
                multival_sep=self._fg_encoded_multival_sep
                if self._fg_mode == data_pb2.FgMode.FG_NONE
                else chr(29),
            )
            self._sampler.init_cluster(num_client_per_rank, client_id_bias, cluster)
            if cluster is None:
                self._sampler.launch_server()

    def get_sampler_cluster(self) -> Optional[Dict[str, Union[int, str]]]:
        """Get sampler cluster, or None when no sampler has been launched."""
        if self._sampler:
            return self._sampler._cluster
        return None

    def _init_input_fields(self) -> None:
        """Init input fields info from the reader schema.

        Raises:
            ValueError: if a column type is not in AVAILABLE_PA_TYPES.
        """
        self._input_fields = []
        for field in self._reader.schema:
            field_type = remove_nullable(field.type)
            # Compared with == rather than a set lookup, so that matching
            # relies on DataType equality only.
            if any(map(lambda x: x == field_type, AVAILABLE_PA_TYPES)):
                self._input_fields.append(field)
            else:
                raise ValueError(
                    f"column [{field.name}] with dtype {field.type} "
                    "is not supported now."
                )

    @property
    def input_fields(self) -> List[pa.Field]:
        """Input fields info, overwrote by subclass for auto infer the info."""
        if not self._input_fields:
            self._input_fields = list(self._data_config.input_fields)
        return self._input_fields

    def get_worker_info(self) -> Tuple[int, int]:
        """Get the global worker id and worker count.

        Combines the dataloader worker id / count with the distributed rank /
        world size.

        Returns:
            (rank * num_workers + worker_id, num_workers * world_size)
        """
        worker_info = get_worker_info()
        if worker_info is None:
            worker_id = 0
            num_workers = 1
        else:
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
        if dist.is_initialized():
            rank = dist.get_rank()
            world_size = dist.get_world_size()
        else:
            rank = 0
            world_size = 1
        return rank * num_workers + worker_id, num_workers * world_size

    def __iter__(self) -> Iterator[Batch]:
        # The sampler is initialised lazily, on first iteration.
        if self._sampler is not None and not self._sampler_inited:
            self._sampler.init()
            self._sampler_inited = True
        worker_id, num_workers = self.get_worker_info()
        for input_data in self._reader.to_batches(worker_id, num_workers):
            yield self._build_batch(input_data)

    def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
        """Process input data and build batch.

        Optionally attaches random sample masks (train mode only), merges in
        sampler output, parses the columns with the data parser and, in
        predict mode, attaches the reserved columns to the batch.

        Args:
            input_data (dict): raw input data, column name to array. It may be
                modified in place.

        Returns:
            an instance of Batch.
        """
        use_sample_mask = self._mode == Mode.TRAIN and (
            self._data_config.negative_sample_mask_prob > 0
            or self._data_config.sample_mask_prob > 0
        )
        if use_sample_mask:
            input_data[C_SAMPLE_MASK] = pa.array(
                np.random.random(len(list(input_data.values())[0]))
                < self._data_config.sample_mask_prob
            )

        if self._sampler is not None:
            # Columns the sampler appends after the original rows. It stays
            # None for TDM, whose samples are expanded into ordinary rows.
            sampled: Optional[Dict[str, pa.Array]] = None
            if isinstance(self._sampler, TDMSampler):
                pos_sampled, neg_sampled = self._sampler.get(input_data)
                input_data = _expand_tdm_sample(
                    input_data, pos_sampled, neg_sampled, self._data_config
                )
            elif self._enable_hstu:
                seq_attr = self._sampler._item_id_field
                (
                    input_data_k_split,
                    input_data_k_split_slice,
                    pre_seq_filter_reshaped_joined,
                ) = process_hstu_seq_data(
                    input_data=input_data,
                    seq_attr=seq_attr,
                    seq_str_delim=self._sampler.item_id_delim,
                )
                if self._mode == Mode.TRAIN:
                    # Training using all possible target items
                    input_data[seq_attr] = input_data_k_split_slice
                elif self._mode == Mode.EVAL:
                    # Evaluation using the last item for previous sequence
                    input_data[seq_attr] = input_data_k_split.values.take(
                        pa.array(input_data_k_split.offsets.to_numpy()[1:] - 1)
                    )
                sampled = self._sampler.get(input_data)
                # To keep consistent with other process, use two functions
                for k, v in sampled.items():
                    if k in input_data:
                        combined = process_hstu_neg_sample(
                            input_data,
                            v,
                            self._sampler._num_sample,
                            self._sampler.item_id_delim,
                            seq_attr,
                        )
                        # Combine here to make embeddings of both user sequence
                        # and target item are the same
                        input_data[k] = pa.concat_arrays(
                            [pre_seq_filter_reshaped_joined, combined]
                        )
                    else:
                        input_data[k] = v
            else:
                sampled = self._sampler.get(input_data)
                for k, v in sampled.items():
                    if k in input_data:
                        input_data[k] = pa.concat_arrays([input_data[k], v])
                    else:
                        input_data[k] = v

            # TDM appends no separate sampled block to mask. Its expanded rows
            # are regular rows, and C_SAMPLE_MASK was expanded together with
            # the other non-item columns by _expand_tdm_sample.
            if use_sample_mask and sampled is not None:
                input_data[C_NEG_SAMPLE_MASK] = pa.concat_arrays(
                    [
                        input_data[C_SAMPLE_MASK],
                        pa.array(
                            np.random.random(len(list(sampled.values())[0]))
                            < self._data_config.negative_sample_mask_prob
                        ),
                    ]
                )

        # TODO(hongsheng.jhs): add additional field like hard_negative
        output_data = self._data_parser.parse(input_data)
        if self._mode == Mode.PREDICT:
            batch = self._data_parser.to_batch(output_data, force_no_tile=True)
            reserved_data = {}
            if self._reserves_all_columns():
                reserved_data = input_data
            else:
                for k in self._reserved_columns:
                    reserved_data[k] = input_data[k]
            if self._debug_level > 0:
                reserved_data["__features__"] = self._data_parser.dump_parsed_inputs(
                    output_data
                )
            if len(reserved_data) > 0:
                batch.reserves = RecordBatchTensor(pa.record_batch(reserved_data))
        else:
            batch = self._data_parser.to_batch(output_data)
        return batch

    @property
    def sampled_batch_size(self) -> int:
        """Batch size with sampler (batch size plus estimated sample count)."""
        if self._sampler:
            return self._batch_size + self._sampler.estimated_sample_num
        else:
            return self._batch_size
class BaseReader(metaclass=_reader_meta_cls):
    """Reader base class.

    Args:
        input_path (str): data input path.
        batch_size (int): batch size.
        selected_cols (list): selection column names.
        drop_remainder (bool): drop last batch.
        shuffle (bool): shuffle data or not.
        shuffle_buffer_size (int): buffer size for shuffle.
    """

    def __init__(
        self,
        input_path: str,
        batch_size: int,
        selected_cols: Optional[List[str]] = None,
        drop_remainder: bool = False,
        shuffle: bool = False,
        shuffle_buffer_size: int = 32,
        **kwargs: Any,
    ) -> None:
        self._input_path = input_path
        self._batch_size = batch_size
        self._selected_cols = selected_cols
        self._drop_remainder = drop_remainder
        self._shuffle = shuffle
        self._shuffle_buffer_size = shuffle_buffer_size

    def to_batches(
        self, worker_id: int = 0, num_workers: int = 1
    ) -> Iterator[Dict[str, pa.Array]]:
        """Get batch iterator."""
        raise NotImplementedError

    def _arrow_reader_iter(
        self, reader: Iterator[pa.RecordBatch]
    ) -> Iterator[Dict[str, pa.Array]]:
        """Re-batch a stream of arrow record batches into fixed-size batches.

        Incoming record batches are accumulated and sliced so that each
        emitted batch has exactly ``batch_size`` rows. The exception is a
        final, smaller remainder, which is skipped when ``drop_remainder`` is
        set. A zero-row remainder is never emitted. With ``shuffle`` enabled,
        batches pass through a buffer of ``shuffle_buffer_size`` entries and
        are emitted in random order.

        Args:
            reader: iterator of arrow record batches.

        Yields:
            dict mapping column name to a single-chunk pa.Array.
        """
        shuffle_buffer = []
        buff_data = None
        while True:
            data = None
            if buff_data is None or len(buff_data) < self._batch_size:
                # Not enough rows buffered for a full batch: pull more input.
                try:
                    read_data = next(reader)
                except StopIteration:
                    # Input exhausted: flush the leftover rows unless they are
                    # to be dropped. A zero-row leftover (e.g. the source ended
                    # with empty record batches) is not a batch.
                    if (
                        not self._drop_remainder
                        and buff_data is not None
                        and len(buff_data) > 0
                    ):
                        data = buff_data
                    buff_data = None
                else:
                    read_table = pa.Table.from_batches([read_data])
                    if buff_data is None:
                        buff_data = read_table
                    else:
                        buff_data = pa.concat_tables([buff_data, read_table])
            elif len(buff_data) == self._batch_size:
                data = buff_data
                buff_data = None
            else:
                data = buff_data.slice(0, self._batch_size)
                buff_data = buff_data.slice(self._batch_size)

            if data is not None:
                data_dict = {}
                for name, column in zip(data.column_names, data.columns):
                    if isinstance(column, pa.ChunkedArray):
                        column = column.combine_chunks()
                    data_dict[name] = column
                if self._shuffle:
                    shuffle_buffer.append(data_dict)
                    if len(shuffle_buffer) < self._shuffle_buffer_size:
                        continue
                    idx = random.randrange(len(shuffle_buffer))
                    data_dict = shuffle_buffer.pop(idx)
                yield data_dict

            if data is None and buff_data is None:
                break

        # Flush whatever is still waiting in the shuffle buffer.
        if len(shuffle_buffer) > 0:
            random.shuffle(shuffle_buffer)
            for data_dict in shuffle_buffer:
                yield data_dict
class BaseWriter(metaclass=_writer_meta_cls):
    """Writer base class.

    Args:
        output_path (str): data output path.
    """

    def __init__(self, output_path: str, **kwargs: Any) -> None:
        # Subclasses presumably set this to True once their output is opened
        # (lazily, e.g. on first write); close() resets it.
        self._lazy_inited = False
        self._output_path = output_path

    def write(self, output_dict: OrderedDict[str, pa.Array]) -> None:
        """Write a batch of data."""
        raise NotImplementedError

    def close(self) -> None:
        """Close and commit data."""
        self._lazy_inited = False

    def __del__(self) -> None:
        # getattr: __init__ may not have run (e.g. a subclass constructor
        # raised before calling super().__init__). An AttributeError here
        # would only surface as an "Exception ignored in __del__" message.
        if getattr(self, "_lazy_inited", False):
            # pyre-ignore [16]
            logger.warning(f"You should close {self.__class__.__name__} explicitly.")
def create_reader(
    input_path: str,
    batch_size: int,
    selected_cols: Optional[List[str]] = None,
    reader_type: Optional[str] = None,
    **kwargs: Any,
) -> BaseReader:
    """Create a data reader for ``input_path``.

    The reader class is inferred from the path:
    - ``odps://`` tables use OdpsReader.
    - ``*.csv`` uses CsvReader.
    - ``*.parquet`` uses ParquetReader.

    For any other path ``reader_type`` must name the reader class. It is
    ignored when the class can be inferred.

    Args:
        input_path (str): data input path.
        batch_size (int): batch size.
        selected_cols (list): selection column names.
        reader_type (str, optional): specify the input path reader type, if
            we cannot infer it from input_path.
        **kwargs: additional params.

    Returns:
        reader: a data reader.
    """
    cls_name = None
    if input_path.startswith("odps://"):
        cls_name = "OdpsReader"
    else:
        suffix_readers = ((".csv", "CsvReader"), (".parquet", "ParquetReader"))
        for suffix, candidate in suffix_readers:
            if input_path.endswith(suffix):
                cls_name = candidate
                break

    if cls_name is None:
        assert reader_type is not None, "You should set reader_type."
        cls_name = reader_type

    # pyre-ignore [16]
    reader_cls = BaseReader.create_class(cls_name)
    return reader_cls(
        input_path=input_path,
        batch_size=batch_size,
        selected_cols=selected_cols,
        **kwargs,
    )
def create_writer(
    output_path: str, writer_type: Optional[str] = None, **kwargs: Any
) -> BaseWriter:
    """Create a data writer for ``output_path``.

    ``odps://`` paths always use OdpsWriter. For any other path,
    ``writer_type`` must name the writer class.

    Args:
        output_path (str): data output path.
        writer_type (str, optional): specify the output path writer type, if
            we cannot infer it from output_path.
        **kwargs: additional params.

    Returns:
        writer: a data writer.
    """
    if not output_path.startswith("odps://"):
        assert writer_type is not None, "You should set writer_type."
        cls_name = writer_type
    else:
        cls_name = "OdpsWriter"
    # pyre-ignore [16]
    writer_cls = BaseWriter.create_class(cls_name)
    return writer_cls(output_path=output_path, **kwargs)