in tensorflow_decision_forests/tensorflow/core.py [0:0]
def collect_training_examples(
    inputs: Dict[str, SemanticTensor],
    model_id: str,
    collect_training_data: Optional[bool] = True) -> tf.Operation:
    """Builds the op that collects one batch of training examples.

    Every feature is handed to the training op matching its semantic and
    tensor dtype. The feature values are appended to a set of column-wise
    in-memory accumulators contained in tf resources with respective names
    "_input_key_to_id(model_id, key)".

    Args:
      inputs: Features to collect, indexed by feature name.
      model_id: Id of the model.
      collect_training_data: Indicate if the examples are used for training.
        Forwarded to `_input_key_to_id` when computing the accumulator id.

    Returns:
      Op triggering the collection of all the features.

    Raises:
      Exception: If the semantic / dtype combination of a feature is not
        supported.
    """

    def fail_unsupported(feature_name, feature):
        # Single place where the "unsupported feature" error is produced.
        raise Exception(
            "Non supported tensor dtype {} and semantic {} for feature {}".format(
                feature.tensor.dtype, feature.semantic, feature_name))

    collection_ops = []
    for key, semantic_tensor in inputs.items():
        input_id = _input_key_to_id(model_id, key, collect_training_data)
        semantic = semantic_tensor.semantic
        tensor = semantic_tensor.tensor

        # Training op selected for this feature; stays None when the
        # semantic / dtype combination is not supported.
        collector = None
        # Keyword arguments used by all the non-set features.
        op_kwargs = {"value": tensor, "id": input_id, "feature_name": key}

        if semantic == Semantic.NUMERICAL:
            if tensor.dtype == NormalizedNumericalType:
                collector = training_op.simple_ml_numerical_feature

        elif semantic == Semantic.CATEGORICAL:
            if tensor.dtype == NormalizedCategoricalStringType:
                collector = training_op.simple_ml_categorical_string_feature
            elif tensor.dtype == NormalizedCategoricalIntType:
                collector = training_op.simple_ml_categorical_int_feature

        elif semantic == Semantic.CATEGORICAL_SET:
            # Set features are passed as flat values + row splits. Both
            # attributes are read before the dtype is checked.
            op_kwargs = {
                "values": tensor.values,
                "row_splits": tensor.row_splits,
                "id": input_id,
                "feature_name": key,
            }
            if tensor.dtype == NormalizedCategoricalSetStringType:
                collector = training_op.simple_ml_categorical_set_string_feature
            # NOTE(review): the integer case is compared against
            # NormalizedCategoricalIntType and not a set-specific constant --
            # confirm this is intended.
            elif tensor.dtype == NormalizedCategoricalIntType:
                collector = training_op.simple_ml_categorical_set_int_feature

        elif semantic == Semantic.HASH:
            if tensor.dtype == NormalizedHashType:
                collector = training_op.simple_ml_hash_feature

        elif semantic == Semantic.BOOLEAN:
            # Boolean features are not yet supported for training in TF-DF.
            pass

        if collector is None:
            fail_unsupported(key, semantic_tensor)
        collection_ops.append(collector(**op_kwargs))

    return tf.group(collection_ops)