in tzrec/datasets/data_parser.py [0:0]
def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:
    """Parse input data dict and build batch.

    Args:
        input_data (dict): raw input data, mapping column name to arrow array.
            The caller's dict is not modified; in input-tile mode a new dict
            is built in which the user columns are truncated to one row.

    Return:
        output_data (dict): parsed feature data. It also contains one tensor
            per label column (float32 for float labels, int64 for integer
            labels) and one float32 tensor per sample weight column. In
            input-tile mode it additionally holds a scalar ``batch_size``
            tensor.

    Raises:
        ValueError: if a label column is neither an int nor a float type, or
            a sample weight column is not a float type.
    """
    output_data: Dict[str, torch.Tensor] = {}
    if is_input_tile() and input_data:
        # The batch size is read from the first column before any column is
        # truncated (same source column as before; presumably every column
        # carries the full batch length).
        first_column = next(iter(input_data.values()))
        output_data["batch_size"] = torch.tensor(len(first_column))

        # Which columns count as "user" columns depends on the fg mode; the
        # choice does not change per column, so make it once.
        if self._fg_mode in (FgMode.FG_NONE, FgMode.FG_BUCKETIZE):
            tile_keys = self.user_feats
        else:
            tile_keys = self.user_inputs

        # Build a new dict instead of writing into the caller's one, so the
        # caller keeps its full-length columns. User columns keep only their
        # first row (presumably they are repeated across the batch in tile
        # mode - confirm against the tiling logic).
        input_data = {
            k: (v.take([0]) if k in tile_keys else v)
            for k, v in input_data.items()
        }

    if self._fg_mode in (FgMode.FG_DAG, FgMode.FG_BUCKETIZE):
        self._parse_feature_fg_handler(input_data, output_data)
    else:
        self._parse_feature_normal(input_data, output_data)

    for label_name in self._labels:
        label = input_data[label_name]
        if pa.types.is_floating(label.type):
            label_np = label.cast(pa.float32(), safe=False).to_numpy()
        elif pa.types.is_integer(label.type):
            label_np = label.cast(pa.int64(), safe=False).to_numpy()
        else:
            raise ValueError(
                f"label column [{label_name}] only support int or float dtype now."
            )
        output_data[label_name] = _to_tensor(label_np)

    for weight_name in self._sample_weights:
        weight = input_data[weight_name]
        if not pa.types.is_floating(weight.type):
            raise ValueError(
                f"sample weight column [{weight_name}] should be float dtype."
            )
        output_data[weight_name] = _to_tensor(
            weight.cast(pa.float32(), safe=False).to_numpy()
        )
    return output_data