def collect_training_examples()

in tensorflow_decision_forests/tensorflow/core.py [0:0]


def collect_training_examples(
    inputs: Dict[str, SemanticTensor],
    model_id: str,
    collect_training_data: Optional[bool] = True) -> tf.Operation:
  """Collects a batch of training examples.

  The feature values are appended to a set of column-wise in-memory
  accumulators contained in tf resources with respective names
  "_input_key_to_id(model_id, key, collect_training_data)".

  Args:
    inputs: Features to collect, keyed by feature name.
    model_id: Id of the model.
    collect_training_data: Indicate if the examples are used for training.
      Forwarded to `_input_key_to_id` when building each resource id.

  Returns:
    Op (a `tf.group` of the per-feature collection ops) triggering the
    collection.

  Raises:
    ValueError: If a feature has a semantic / dtype combination that is not
      supported for training (this includes `Semantic.BOOLEAN`).
  """

  def raise_non_supported(key, semantic_tensor):
    # Defined once, outside the loop, and parameterized explicitly so it does
    # not capture loop variables by closure. ValueError is a subclass of
    # Exception, so callers catching the previous generic Exception still work.
    raise ValueError(
        "Non supported tensor dtype {} and semantic {} for feature {}".format(
            semantic_tensor.tensor.dtype, semantic_tensor.semantic, key))

  ops = []
  for key, semantic_tensor in inputs.items():
    input_id = _input_key_to_id(model_id, key, collect_training_data)
    dtype = semantic_tensor.tensor.dtype

    if semantic_tensor.semantic == Semantic.NUMERICAL:
      if dtype == NormalizedNumericalType:
        ops.append(
            training_op.simple_ml_numerical_feature(
                value=semantic_tensor.tensor, id=input_id, feature_name=key))
      else:
        raise_non_supported(key, semantic_tensor)

    elif semantic_tensor.semantic == Semantic.CATEGORICAL:
      if dtype == NormalizedCategoricalStringType:
        ops.append(
            training_op.simple_ml_categorical_string_feature(
                value=semantic_tensor.tensor, id=input_id, feature_name=key))
      elif dtype == NormalizedCategoricalIntType:
        ops.append(
            training_op.simple_ml_categorical_int_feature(
                value=semantic_tensor.tensor, id=input_id, feature_name=key))
      else:
        raise_non_supported(key, semantic_tensor)

    elif semantic_tensor.semantic == Semantic.CATEGORICAL_SET:
      # Categorical-set features are passed as (values, row_splits), so the
      # tensor is presumably a tf.RaggedTensor here — TODO confirm with caller.
      args = {
          "values": semantic_tensor.tensor.values,
          "row_splits": semantic_tensor.tensor.row_splits,
          "id": input_id,
          "feature_name": key
      }
      if dtype == NormalizedCategoricalSetStringType:
        ops.append(training_op.simple_ml_categorical_set_string_feature(**args))
      # NOTE(review): compared against the non-set NormalizedCategoricalIntType;
      # if a NormalizedCategoricalSetIntType constant exists in this module it
      # would be the more consistent choice — confirm they are the same dtype.
      elif dtype == NormalizedCategoricalIntType:
        ops.append(training_op.simple_ml_categorical_set_int_feature(**args))
      else:
        raise_non_supported(key, semantic_tensor)

    elif semantic_tensor.semantic == Semantic.HASH:
      if dtype == NormalizedHashType:
        ops.append(
            training_op.simple_ml_hash_feature(
                value=semantic_tensor.tensor, id=input_id, feature_name=key))
      else:
        raise_non_supported(key, semantic_tensor)

    elif semantic_tensor.semantic == Semantic.BOOLEAN:
      # Boolean features are not yet supported for training in TF-DF.
      raise_non_supported(key, semantic_tensor)

    else:
      raise_non_supported(key, semantic_tensor)

  return tf.group(ops)