# tzrec/datasets/utils.py
# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple
import numpy as np
import numpy.typing as npt
import pyarrow as pa
import pyarrow.compute as pc
import torch
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
from torchrec.streamable import Pipelineable
from tzrec.protos.data_pb2 import FieldType
BASE_DATA_GROUP = "__BASE__"
NEG_DATA_GROUP = "__NEG__"
CROSS_NEG_DATA_GROUP = "__CNEG__"
C_SAMPLE_MASK = "__SAMPLE_MASK__"
C_NEG_SAMPLE_MASK = "__NEG_SAMPLE_MASK__"
FIELD_TYPE_TO_PA = {
FieldType.INT32: pa.int32(),
FieldType.INT64: pa.int64(),
FieldType.FLOAT: pa.float32(),
FieldType.DOUBLE: pa.float64(),
FieldType.STRING: pa.string(),
}
@dataclass
class ParsedData:
"""Internal parsed data structure."""
name: str
@dataclass
class SparseData(ParsedData):
"""Internal data structure for sparse feature."""
values: npt.NDArray
lengths: npt.NDArray
weights: Optional[npt.NDArray] = None
@dataclass
class DenseData(ParsedData):
"""Internal data structure for dense feature."""
values: npt.NDArray
@dataclass
class SequenceSparseData(ParsedData):
"""Internal data structure for sequence sparse feature."""
values: npt.NDArray
key_lengths: npt.NDArray
seq_lengths: npt.NDArray
@dataclass
class SequenceDenseData(ParsedData):
"""Internal data structure for sequence dense feature."""
values: npt.NDArray
seq_lengths: npt.NDArray
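# Illustrative sketch (not part of the original file): a jagged id-list
# feature with rows [[1, 2], [3]] maps to SparseData as flattened values
# plus per-row lengths; the feature name "cate_id" is hypothetical.
#
#   feat = SparseData(
#       name="cate_id",
#       values=np.array([1, 2, 3], dtype=np.int64),
#       lengths=np.array([2, 1], dtype=np.int32),
#   )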
class RecordBatchTensor:
"""PyArrow RecordBatch use Tensor as buffer.
For efficient transfer data between processes, e.g., mp.Queue.
"""
def __init__(self, record_batch: Optional[pa.RecordBatch] = None) -> None:
self._schema = None
self._buff = None
if record_batch:
self._schema = record_batch.schema
self._buff = torch.UntypedStorage.from_buffer(
record_batch.serialize(), dtype=torch.uint8
)
def get(self) -> Optional[pa.RecordBatch]:
"""Get RecordBatch."""
if self._buff is not None:
# pyre-ignore[16]
return pa.ipc.read_record_batch(
pa.foreign_buffer(self._buff.data_ptr(), self._buff.size()),
self._schema,
)
else:
return None
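# Illustrative sketch (not part of the original file): round-tripping a
# RecordBatch through RecordBatchTensor, e.g., before and after crossing
# an mp.Queue.
#
#   rb = pa.RecordBatch.from_pydict({"id": [1, 2, 3]})
#   rbt = RecordBatchTensor(rb)
#   assert rbt.get().equals(rb)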
@dataclass
class Batch(Pipelineable):
"""Input Batch."""
# key of dense_features is data group name
dense_features: Dict[str, KeyedTensor] = field(default_factory=dict)
# key of sparse_features is data group name
sparse_features: Dict[str, KeyedJaggedTensor] = field(default_factory=dict)
# key of sequence_mulval_lengths is data group name
#
# for a multi-value sequence, we flatten it, then store values & accumulated
# lengths into sparse_features, and store key_lengths & seq_lengths into
# sequence_mulval_lengths
#
# e.g.
# for the sequence `click_seq`: [[[3, 4], [5]], [[6], [7, 8]]]
# we can denote it in jagged format with:
#   values: [3, 4, 5, 6, 7, 8]
#   key_lengths: [2, 1, 1, 2]
#   seq_lengths: [2, 2]
# then:
#   sparse_features[dg]['click_seq'].values() = [3, 4, 5, 6, 7, 8]  # values
#   sparse_features[dg]['click_seq'].lengths() = [3, 3]  # accumulated lengths
#   sequence_mulval_lengths[dg]['click_seq'].values() = [2, 1, 1, 2]  # key_lengths
#   sequence_mulval_lengths[dg]['click_seq'].lengths() = [2, 2]  # seq_lengths
sequence_mulval_lengths: Dict[str, KeyedJaggedTensor] = field(default_factory=dict)
# key of sequence_dense_features is feature name
sequence_dense_features: Dict[str, JaggedTensor] = field(default_factory=dict)
# key of labels is label name
labels: Dict[str, torch.Tensor] = field(default_factory=dict)
# reserved inputs [for predict]
reserves: RecordBatchTensor = field(default_factory=RecordBatchTensor)
# tile size for user-side inputs at inference time when INPUT_TILE=2 or 3
tile_size: int = field(default=-1)
# key of sample_weights is sample weight name
sample_weights: Dict[str, torch.Tensor] = field(default_factory=dict)
def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
"""Copy to specified device."""
return Batch(
dense_features={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.dense_features.items()
},
sparse_features={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.sparse_features.items()
},
sequence_mulval_lengths={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.sequence_mulval_lengths.items()
},
sequence_dense_features={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.sequence_dense_features.items()
},
labels={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.labels.items()
},
reserves=self.reserves,
tile_size=self.tile_size,
sample_weights={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.sample_weights.items()
},
)
def record_stream(self, stream: torch.Stream) -> None:
"""Record which streams have used the tensor."""
for v in self.dense_features.values():
v.record_stream(stream)
for v in self.sparse_features.values():
v.record_stream(stream)
for v in self.sequence_mulval_lengths.values():
v.record_stream(stream)
for v in self.sequence_dense_features.values():
v.record_stream(stream)
for v in self.labels.values():
v.record_stream(stream)
for v in self.sample_weights.values():
v.record_stream(stream)
def pin_memory(self) -> "Batch":
"""Copy to pinned memory."""
# TODO(hongsheng.jhs): KeyedTensor do not have pin_memory()
dense_features = {}
for k, v in self.dense_features.items():
dense_features[k] = KeyedTensor(
keys=v.keys(),
length_per_key=v.length_per_key(),
values=v.values().pin_memory(),
key_dim=v.key_dim(),
)
sequence_dense_features = {}
for k, v in self.sequence_dense_features.items():
weights = v._weights
lengths = v._lengths
offsets = v._offsets
sequence_dense_features[k] = JaggedTensor(
values=v.values().pin_memory(),
weights=weights.pin_memory() if weights is not None else None,
lengths=lengths.pin_memory() if lengths is not None else None,
offsets=offsets.pin_memory() if offsets is not None else None,
)
return Batch(
dense_features=dense_features,
sparse_features={
k: v.pin_memory() for k, v in self.sparse_features.items()
},
sequence_mulval_lengths={
k: v.pin_memory() for k, v in self.sequence_mulval_lengths.items()
},
sequence_dense_features=sequence_dense_features,
labels={k: v.pin_memory() for k, v in self.labels.items()},
reserves=self.reserves,
tile_size=self.tile_size,
sample_weights={k: v.pin_memory() for k, v in self.sample_weights.items()},
)
def to_dict(
self, sparse_dtype: Optional[torch.dtype] = None
) -> Dict[str, torch.Tensor]:
"""Convert to feature tensor dict."""
tensor_dict = {}
for x in self.dense_features.values():
for k, v in x.to_dict().items():
tensor_dict[f"{k}.values"] = v
for x in self.sparse_features.values():
if sparse_dtype:
x = KeyedJaggedTensor(
keys=x.keys(),
values=x.values().to(sparse_dtype),
lengths=x.lengths().to(sparse_dtype),
weights=x.weights_or_none(),
)
for k, v in x.to_dict().items():
tensor_dict[f"{k}.values"] = v.values()
tensor_dict[f"{k}.lengths"] = v.lengths()
if v.weights_or_none() is not None:
tensor_dict[f"{k}.weights"] = v.weights()
for x in self.sequence_mulval_lengths.values():
if sparse_dtype:
x = KeyedJaggedTensor(
keys=x.keys(),
values=x.values().to(sparse_dtype),
lengths=x.lengths().to(sparse_dtype),
)
for k, v in x.to_dict().items():
tensor_dict[f"{k}.key_lengths"] = v.values()
tensor_dict[f"{k}.lengths"] = v.lengths()
for k, v in self.sequence_dense_features.items():
tensor_dict[f"{k}.values"] = v.values()
tensor_dict[f"{k}.lengths"] = v.lengths()
for k, v in self.labels.items():
tensor_dict[f"{k}"] = v
for k, v in self.sample_weights.items():
tensor_dict[f"{k}"] = v
if self.tile_size > 0:
tensor_dict["batch_size"] = torch.tensor(self.tile_size, dtype=torch.int64)
return tensor_dict
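# Illustrative sketch (not part of the original file): building a Batch for
# the `click_seq` example documented above; "dg" is a hypothetical data
# group name.
#
#   sparse = KeyedJaggedTensor(
#       keys=["click_seq"],
#       values=torch.tensor([3, 4, 5, 6, 7, 8]),
#       lengths=torch.tensor([3, 3]),  # accumulated lengths per sequence
#   )
#   mulval = KeyedJaggedTensor(
#       keys=["click_seq"],
#       values=torch.tensor([2, 1, 1, 2]),  # key_lengths
#       lengths=torch.tensor([2, 2]),  # seq_lengths
#   )
#   batch = Batch(
#       sparse_features={"dg": sparse},
#       sequence_mulval_lengths={"dg": mulval},
#   )
#   # batch.to_dict() exports "click_seq.values", "click_seq.lengths",
#   # "click_seq.key_lengths", etc.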
def process_hstu_seq_data(
input_data: Dict[str, pa.Array],
seq_attr: str,
seq_str_delim: str,
) -> Tuple[pa.Array, pa.Array, pa.Array]:
"""Process sequence data for HSTU match model.
Args:
input_data: Dictionary containing input arrays.
seq_attr: name of the sequence attribute field.
seq_str_delim: delimiter used to separate sequence items.
Returns:
Tuple containing:
- input_data_k_split: pa.ListArray, original sequences split by delimiter
- input_data_k_split_slice: pa.Array, target items for autoregressive training
- pre_seq_filter_reshaped_joined: pa.Array,
training sequences (each row with its last item dropped)
"""
# by default, sequence data is a string
if pa.types.is_string(input_data[seq_attr].type):
input_data_k_split = pc.split_pattern(input_data[seq_attr], seq_str_delim)
# Get target items for autoregressive training
# Example: [1,2,3,4,5] -> [2,3,4,5]
input_data_k_split_slice = pc.list_flatten(
pc.list_slice(input_data_k_split, start=1)
)
# Directly extract the training sequence for autoregressive training
# Operation target example: [1,2,3,4,5] -> [1,2,3,4]
# (corresponding target: [2,3,4,5])
# As this cannot be done with pyarrow.compute alone, and a Python for
# loop is costly, we build it with pa.ListArray.from_arrays using offsets
# 1. convert to numpy and drop the last item of each sequence
pre_seq = pc.list_flatten(input_data_k_split).to_numpy(zero_copy_only=False)
# Mark the last item of each sequence with '-1'
pre_seq[input_data_k_split.offsets.to_numpy()[1:] - 1] = "-1"
# Filter out -1 marker elements
mask = pre_seq != "-1"
pre_seq_filter = pre_seq[mask]
# 2. create offsets for reshaping the filtered sequence
# The offsets must be shifted to account for the one item removed from
# each sequence. Example: if the original offsets are [0,2,5,9], the
# filtered offsets should be [0, 1, 3, 6],
# that is, [0] + [2-1, 5-2, 9-3]
offsets_np = input_data_k_split.offsets.to_numpy(zero_copy_only=False)
pre_seq_filter_offsets = pa.array(
np.concatenate(
[np.array([0]), offsets_np[1:] - np.arange(1, len(offsets_np))]
)
)
pre_seq_filter_reshaped = pa.ListArray.from_arrays(
pre_seq_filter_offsets, pre_seq_filter
)
# Join filtered sequence with delimiter
pre_seq_filter_reshaped_joined = pc.binary_join(
pre_seq_filter_reshaped, seq_str_delim
)
return (
input_data_k_split,
input_data_k_split_slice,
pre_seq_filter_reshaped_joined,
)
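# Illustrative sketch (not part of the original file): one row "1;2;3;4;5"
# split into the full sequence, the shifted targets, and the training
# sequence with its last item dropped; "click_seq" is a hypothetical field.
#
#   data = {"click_seq": pa.array(["1;2;3;4;5"])}
#   seq, target, train_seq = process_hstu_seq_data(data, "click_seq", ";")
#   # seq:       [["1", "2", "3", "4", "5"]]
#   # target:    ["2", "3", "4", "5"]
#   # train_seq: ["1;2;3;4"]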
def process_hstu_neg_sample(
input_data: Dict[str, pa.Array],
v: pa.Array,
neg_sample_num: int,
seq_str_delim: str,
seq_attr: str,
) -> pa.Array:
"""Process negative samples for HSTU match model.
Args:
input_data: Dict[str, pa.Array], Dictionary containing input arrays
v: pa.Array, negative samples.
neg_sample_num: int, number of negative samples.
seq_str_delim: str, delimiter for sequence string.
seq_attr: str, attribute name of sequence.
Returns:
pa.Array: Processed negative samples
"""
# The goal is to concat the neg samples onto each training sequence
# Example:
# input_data[seq_attr] = ["1", "2", "3"]
# neg_sample_num = 2
# v = [4, 5, 6, 7, 8, 9]
# then the output should be ["1;4;5", "2;6;7", "3;8;9"]
v_str = v.cast(pa.string())
filtered_v_offsets = pa.array(
np.concatenate(
[
np.array([0]),
np.arange(neg_sample_num, len(v_str) + 1, neg_sample_num),
]
)
)
# Reshape v for each input_data[seq_attr]
# Example:[4,5,6,7,8,9] -> [[4,5], [6,7], [8,9]]
filtered_v_palist = pa.ListArray.from_arrays(filtered_v_offsets, v_str)
# Join as strings, since no ListArray operation achieves this directly
# Example: [[4,5], [6,7], [8,9]] -> ["4;5", "6;7", "8;9"]
sampled_joined = pc.binary_join(filtered_v_palist, seq_str_delim)
# Append the negative samples to each training sequence
# Example: ["1", "2", "3"] + ["4;5", "6;7", "8;9"]
# -> ["1;4;5", "2;6;7", "3;8;9"]
return pc.binary_join_element_wise(
input_data[seq_attr], sampled_joined, seq_str_delim
)
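# Illustrative sketch (not part of the original file): appending two
# negatives to each row of a hypothetical "click_seq" column.
#
#   data = {"click_seq": pa.array(["1", "2", "3"])}
#   negs = pa.array([4, 5, 6, 7, 8, 9])
#   out = process_hstu_neg_sample(data, negs, 2, ";", "click_seq")
#   # out: ["1;4;5", "2;6;7", "3;8;9"]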
def calc_slice_position(
row_count: int,
slice_id: int,
slice_count: int,
batch_size: int,
drop_redundant_bs_eq_one: bool,
pre_total_remain: int = 0,
) -> Tuple[int, int, int]:
"""Calc table read position according to the slice information.
Args:
row_count (int): table total row count.
slice_id (int): worker id.
slice_count (int): total worker number.
batch_size (int): batch_size.
drop_redundant_bs_eq_one (bool): drop the last redundant batch whose
batch_size equals one, to prevent train_eval from hanging.
pre_total_remain (int): remaining row count carried over from the
previous table, i.e., rows insufficient to fill batch_size for
each worker.
Returns:
start (int): start row position in the table.
end (int): end row position in the table.
total_remain (int): remaining row count in the current table that is
insufficient to fill batch_size for each worker.
"""
pre_remain_size = pre_total_remain // slice_count
pre_remain_split_point = pre_total_remain % slice_count
size = (row_count + pre_total_remain) // slice_count
split_point = (row_count + pre_total_remain) % slice_count
if slice_id < split_point:
start = slice_id * (size + 1)
end = start + (size + 1)
else:
start = split_point * (size + 1) + (slice_id - split_point) * size
end = start + size
real_start = (
start - pre_remain_size * slice_id - min(pre_remain_split_point, slice_id)
)
real_end = (
end
- pre_remain_size * (slice_id + 1)
- min(pre_remain_split_point, slice_id + 1)
)
# when (end - start) % batch_size == 1 on some workers and
# (end - start) % batch_size == 0 on other workers, train_eval will hang
if (
drop_redundant_bs_eq_one
and split_point != 0
and (end - start) % batch_size == 1
and size % batch_size == 0
):
real_end = real_end - 1
split_point = 0
return real_start, real_end, (size % batch_size) * slice_count + split_point
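# Illustrative sketch (not part of the original file): splitting a 10-row
# table across 3 workers with batch_size=2 and no carry-over rows.
#
#   calc_slice_position(10, 0, 3, 2, False)  # -> (0, 4, 4)
#   calc_slice_position(10, 1, 3, 2, False)  # -> (4, 7, 4)
#   calc_slice_position(10, 2, 3, 2, False)  # -> (7, 10, 4)
#
# Each worker reads 3-4 rows; total_remain=4 is the carry-over bookkeeping
# to pass as pre_total_remain when reading the next table.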
def remove_nullable(field_type: pa.DataType) -> pa.DataType:
"""Recursive removal of the null=False property from lists and nested lists."""
if pa.types.is_list(field_type):
# Get the element field
value_field = field_type.value_field
# Mark the element field as nullable
normalized_value_field = value_field.with_nullable(True)
# Recursive processing of element types
normalized_value_type = remove_nullable(normalized_value_field.type)
# Construct a new list type
return pa.list_(normalized_value_type)
else:
return field_type
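# Illustrative sketch (not part of the original file): a nested list type
# whose inner field is marked non-nullable is rebuilt with nullable fields.
#
#   t = pa.list_(pa.field("item", pa.list_(pa.int32()), nullable=False))
#   remove_nullable(t)  # -> list<item: list<item: int32>> with nullable items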