in tzrec/datasets/dataset.py [0:0]
def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
    """Process input data and build batch.

    Pipeline: (1) optionally attach a train-time sample mask, (2) run the
    configured sampler (TDM / HSTU / default negative sampling) to expand
    ``input_data`` in place, (3) parse the columns into a ``Batch``,
    attaching reserved columns in PREDICT mode.

    Args:
        input_data (dict): raw input data; a mapping from column name to
            ``pa.Array``. All columns are assumed to have equal length
            (the first column's length is used as the row count below).

    Returns:
        an instance of Batch.
    """
    # Sample masking only applies during training, and only if at least one
    # of the two mask probabilities is configured to be positive.
    use_sample_mask = self._mode == Mode.TRAIN and (
        self._data_config.negative_sample_mask_prob > 0
        or self._data_config.sample_mask_prob > 0
    )
    if use_sample_mask:
        # Per-row boolean mask: True with probability `sample_mask_prob`.
        # Row count is taken from the first column of input_data.
        input_data[C_SAMPLE_MASK] = pa.array(
            np.random.random(len(list(input_data.values())[0]))
            < self._data_config.sample_mask_prob
        )
    if self._sampler is not None:
        if isinstance(self._sampler, TDMSampler):
            # Tree-based (TDM) sampling: fetch positive and negative samples
            # and expand the raw input with them via the module helper.
            pos_sampled, neg_sampled = self._sampler.get(input_data)
            input_data = _expand_tdm_sample(
                input_data, pos_sampled, neg_sampled, self._data_config
            )
        elif self._enable_hstu:
            # HSTU generative-recommender path: the item sequence is stored
            # as a delimited string column named by the sampler's item-id
            # field (private attribute access — presumably stable; verify).
            seq_attr = self._sampler._item_id_field
            (
                input_data_k_split,
                input_data_k_split_slice,
                pre_seq_filter_reshaped_joined,
            ) = process_hstu_seq_data(
                input_data=input_data,
                seq_attr=seq_attr,
                seq_str_delim=self._sampler.item_id_delim,
            )
            if self._mode == Mode.TRAIN:
                # Training using all possible target items
                input_data[seq_attr] = input_data_k_split_slice
            elif self._mode == Mode.EVAL:
                # Evaluation using the last item for previous sequence:
                # offsets[1:] - 1 indexes the final element of each
                # sub-list in the ListArray.
                input_data[seq_attr] = input_data_k_split.values.take(
                    pa.array(input_data_k_split.offsets.to_numpy()[1:] - 1)
                )
            sampled = self._sampler.get(input_data)
            # To keep consistent with other process, use two functions
            for k, v in sampled.items():
                if k in input_data:
                    combined = process_hstu_neg_sample(
                        input_data,
                        v,
                        self._sampler._num_sample,
                        self._sampler.item_id_delim,
                        seq_attr,
                    )
                    # Combine here to make embddings of both user sequence
                    # and target item are the same
                    input_data[k] = pa.concat_arrays(
                        [pre_seq_filter_reshaped_joined, combined]
                    )
                else:
                    # Column only exists in the sampled output; adopt as-is.
                    input_data[k] = v
        else:
            # Default negative sampling: append sampled rows to existing
            # columns, or add brand-new columns directly.
            sampled = self._sampler.get(input_data)
            for k, v in sampled.items():
                if k in input_data:
                    input_data[k] = pa.concat_arrays([input_data[k], v])
                else:
                    input_data[k] = v
        if use_sample_mask:
            # Extend the mask to cover the sampled (negative) rows with an
            # independent Bernoulli(`negative_sample_mask_prob`) draw.
            # NOTE(review): `sampled` is only bound on the HSTU/default
            # paths above — presumably TDM sampling is never combined with
            # sample masking; confirm, otherwise this raises NameError.
            input_data[C_NEG_SAMPLE_MASK] = pa.concat_arrays(
                [
                    input_data[C_SAMPLE_MASK],
                    pa.array(
                        np.random.random(len(list(sampled.values())[0]))
                        < self._data_config.negative_sample_mask_prob
                    ),
                ]
            )
    # TODO(hongsheng.jhs): add additional field like hard_negative
    output_data = self._data_parser.parse(input_data)
    if self._mode == Mode.PREDICT:
        # Prediction: never tile inputs, and carry requested raw columns
        # through as "reserves" so callers can join predictions back.
        batch = self._data_parser.to_batch(output_data, force_no_tile=True)
        reserved_data = {}
        if (
            len(self._reserved_columns) > 0
            and self._reserved_columns[0] == "ALL_COLUMNS"
        ):
            # Sentinel value: reserve every input column.
            reserved_data = input_data
        else:
            for k in self._reserved_columns:
                reserved_data[k] = input_data[k]
        if self._debug_level > 0:
            # Attach a dump of the parsed features for debugging.
            reserved_data["__features__"] = self._data_parser.dump_parsed_inputs(
                output_data
            )
        if len(reserved_data) > 0:
            batch.reserves = RecordBatchTensor(pa.record_batch(reserved_data))
    else:
        batch = self._data_parser.to_batch(output_data)
    return batch