def _build_batch()

in tzrec/datasets/dataset.py [0:0]


    def _build_batch(self, input_data: Dict[str, pa.Array]) -> Batch:
        """Process input data and build batch.

        Optionally applies a random sample mask (train mode only), augments
        ``input_data`` with sampler output (TDM, HSTU, or plain negative
        sampling), then parses everything into a ``Batch``. ``input_data``
        is mutated in place.

        Args:
            input_data (dict): raw input data, keyed by column name.

        Returns:
            an instance of Batch.
        """
        # Sample masking is only active during training and only when at
        # least one mask probability is configured to be non-zero.
        use_sample_mask = self._mode == Mode.TRAIN and (
            self._data_config.negative_sample_mask_prob > 0
            or self._data_config.sample_mask_prob > 0
        )
        if use_sample_mask:
            # Per-row boolean mask: True with probability sample_mask_prob.
            # Row count is taken from the first column of input_data.
            input_data[C_SAMPLE_MASK] = pa.array(
                np.random.random(len(list(input_data.values())[0]))
                < self._data_config.sample_mask_prob
            )
        if self._sampler is not None:
            if isinstance(self._sampler, TDMSampler):
                # Tree-based (TDM) sampling: expand each row with its
                # positive and negative tree-node samples.
                pos_sampled, neg_sampled = self._sampler.get(input_data)
                input_data = _expand_tdm_sample(
                    input_data, pos_sampled, neg_sampled, self._data_config
                )
                # NOTE(review): this branch never binds `sampled`; if
                # use_sample_mask were True here, the negative-mask block
                # below would raise NameError. Presumably sample masks are
                # never enabled together with TDM — verify.
            elif self._enable_hstu:
                seq_attr = self._sampler._item_id_field

                # Split the delimiter-joined item-id sequence column into
                # per-item pieces; also get a sliced view and the re-joined
                # prefix sequence used later when combining negatives.
                (
                    input_data_k_split,
                    input_data_k_split_slice,
                    pre_seq_filter_reshaped_joined,
                ) = process_hstu_seq_data(
                    input_data=input_data,
                    seq_attr=seq_attr,
                    seq_str_delim=self._sampler.item_id_delim,
                )
                if self._mode == Mode.TRAIN:
                    # Training using all possible target items
                    input_data[seq_attr] = input_data_k_split_slice
                elif self._mode == Mode.EVAL:
                    # Evaluation uses only the last item of each sequence:
                    # offsets[1:] - 1 indexes each list's final element.
                    input_data[seq_attr] = input_data_k_split.values.take(
                        pa.array(input_data_k_split.offsets.to_numpy()[1:] - 1)
                    )
                sampled = self._sampler.get(input_data)
                # To keep consistent with other process, use two functions
                for k, v in sampled.items():
                    if k in input_data:
                        combined = process_hstu_neg_sample(
                            input_data,
                            v,
                            self._sampler._num_sample,
                            self._sampler.item_id_delim,
                            seq_attr,
                        )
                        # Combine here so that embeddings of both the user
                        # sequence and the target item are the same.
                        input_data[k] = pa.concat_arrays(
                            [pre_seq_filter_reshaped_joined, combined]
                        )
                    else:
                        input_data[k] = v
            else:
                # Plain negative sampling: append sampled values to existing
                # columns, or add brand-new columns wholesale.
                sampled = self._sampler.get(input_data)
                for k, v in sampled.items():
                    if k in input_data:
                        input_data[k] = pa.concat_arrays([input_data[k], v])
                    else:
                        input_data[k] = v

            if use_sample_mask:
                # Extend the mask over the appended negative samples, drawn
                # with negative_sample_mask_prob; length comes from the first
                # sampled column.
                input_data[C_NEG_SAMPLE_MASK] = pa.concat_arrays(
                    [
                        input_data[C_SAMPLE_MASK],
                        pa.array(
                            np.random.random(len(list(sampled.values())[0]))
                            < self._data_config.negative_sample_mask_prob
                        ),
                    ]
                )

        # TODO(hongsheng.jhs): add additional field like hard_negative
        output_data = self._data_parser.parse(input_data)
        if self._mode == Mode.PREDICT:
            # Predict mode: build the batch without tiling and optionally
            # carry reserved raw columns alongside it for output.
            batch = self._data_parser.to_batch(output_data, force_no_tile=True)
            reserved_data = {}
            if (
                len(self._reserved_columns) > 0
                and self._reserved_columns[0] == "ALL_COLUMNS"
            ):
                # Sentinel value "ALL_COLUMNS" reserves every input column.
                reserved_data = input_data
            else:
                for k in self._reserved_columns:
                    reserved_data[k] = input_data[k]
            if self._debug_level > 0:
                # Attach a dump of the parsed features for debugging.
                reserved_data["__features__"] = self._data_parser.dump_parsed_inputs(
                    output_data
                )
            if len(reserved_data) > 0:
                batch.reserves = RecordBatchTensor(pa.record_batch(reserved_data))
        else:
            batch = self._data_parser.to_batch(output_data)
        return batch