def parse()

in tzrec/datasets/data_parser.py [0:0]


    def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:
        """Parse input data dict and build batch.

        When input tiling is enabled (``is_input_tile()``), every user-side
        column is reduced to its first row via ``take([0])``, and the row count
        of the first column (taken before any truncation) is stored under
        ``output_data["batch_size"]``. NOTE: in that mode the entries of
        ``input_data`` are replaced in place.

        Args:
            input_data (dict): raw input data, mapping column name to an
                arrow array.

        Returns:
            output_data (dict): parsed feature data, plus one tensor per label
                column (float32 for floating labels, int64 for integer labels)
                and one float32 tensor per sample weight column; in tile mode
                it also holds ``batch_size``.

        Raises:
            ValueError: if a label column is neither integer nor floating
                point, or a sample weight column is not floating point.
        """
        output_data = {}
        if is_input_tile():
            # The set of user-side column names depends on the FG mode; it is
            # the same for every column, so resolve it once.
            if self._fg_mode in (FgMode.FG_NONE, FgMode.FG_BUCKETIZE):
                user_keys = self.user_feats
            else:
                user_keys = self.user_inputs
            # Capture the batch size from the first column before any column
            # gets truncated below.
            first_column = next(iter(input_data.values()), None)
            if first_column is not None:
                output_data["batch_size"] = torch.tensor(len(first_column))
            for k, v in input_data.items():
                if k in user_keys:
                    # Presumably user-side values are identical across the
                    # tiled batch, so one row is kept -- TODO confirm.
                    # Reassigning an existing key does not change dict size,
                    # so this is safe while iterating.
                    input_data[k] = v.take([0])

        if self._fg_mode in (FgMode.FG_DAG, FgMode.FG_BUCKETIZE):
            self._parse_feature_fg_handler(input_data, output_data)
        else:
            self._parse_feature_normal(input_data, output_data)

        for label_name in self._labels:
            label = input_data[label_name]
            if pa.types.is_floating(label.type):
                output_data[label_name] = _to_tensor(
                    label.cast(pa.float32(), safe=False).to_numpy()
                )
            elif pa.types.is_integer(label.type):
                output_data[label_name] = _to_tensor(
                    label.cast(pa.int64(), safe=False).to_numpy()
                )
            else:
                raise ValueError(
                    f"label column [{label_name}] only support int or float dtype now."
                )

        for weight_name in self._sample_weights:
            weight = input_data[weight_name]
            if not pa.types.is_floating(weight.type):
                raise ValueError(
                    f"sample weight column [{weight_name}] should be float dtype."
                )
            output_data[weight_name] = _to_tensor(
                weight.cast(pa.float32(), safe=False).to_numpy()
            )

        return output_data