def normalize_inputs()

in tensorflow_decision_forests/tensorflow/core.py [0:0]


def normalize_inputs(
    inputs: Dict[str, SemanticTensor]) -> Dict[str, SemanticTensor]:
  """Converts semantic tensors into the canonical form consumed by the OPs.

  Every feature is handled according to its semantic:
    - NUMERICAL / BOOLEAN: cast to float32, NaN used as the missing value.
    - CATEGORICAL: cast to string ("" as the missing value), or to int32 and
      shifted by CATEGORICAL_INTEGER_OFFSET.
    - HASH: cast to string ("" as the missing value).
    - CATEGORICAL_SET: sparse inputs are first turned into ragged tensors, then
      cast to string, or to int32 and shifted by CATEGORICAL_INTEGER_OFFSET.
  All non-set features are handed to `_unroll_and_normalize`, which is
  responsible for splitting fixed-length features into single-dimensional
  features and for densifying sparse tensors with the given missing value.

  Args:
    inputs: Semantic tensors indexed by feature name.

  Returns:
    Dict of normalized semantic tensors. Entries for non-set features are
    written by `_unroll_and_normalize`, so their keys may differ from the
    input keys (TODO confirm against that helper).

  Raises:
    ValueError: A feature has an unknown semantic, or a dtype that is not
      supported for its semantic.
  """

  def _bad_dtype(name, feature):
    # Shared error for "known semantic, unsupported dtype".
    return ValueError(
        "Non supported tensor dtype {} for semantic {} of feature {}".format(
            feature.tensor.dtype, feature.semantic, name))

  result = {}

  for name, feature in inputs.items():
    semantic = feature.semantic

    if semantic == Semantic.NUMERICAL:
      if feature.tensor.dtype not in FlexibleNumericalTypes:
        raise _bad_dtype(name, feature)
      _unroll_and_normalize(
          tf.cast(feature.tensor, tf.float32), semantic, name, math.nan,
          result)
      continue

    if semantic == Semantic.CATEGORICAL:
      dtype = feature.tensor.dtype
      if dtype in FlexibleCategoricalStringTypes:
        _unroll_and_normalize(
            tf.cast(feature.tensor, tf.string), semantic, name, "", result)
      elif dtype in FlexibleCategoricalIntTypes:
        # The missing value is -1 - OFFSET, presumably so it maps to -1 once
        # `dense_preprocess` adds the offset back (confirm in
        # `_unroll_and_normalize`).
        _unroll_and_normalize(
            tf.cast(feature.tensor, tf.int32),
            semantic,
            name,
            -1 - CATEGORICAL_INTEGER_OFFSET,
            result,
            dense_preprocess=lambda x: x + CATEGORICAL_INTEGER_OFFSET)
      else:
        raise _bad_dtype(name, feature)
      continue

    if semantic == Semantic.CATEGORICAL_SET:
      # Sparse set values are converted to ragged before the dtype check; the
      # dtype itself is read from the original tensor.
      values = feature.tensor
      if isinstance(values, tf.SparseTensor):
        values = tf.RaggedTensor.from_sparse(values)

      dtype = feature.tensor.dtype
      if dtype in FlexibleCategoricalSetStringTypes:
        values = tf.cast(values, tf.string)
      elif dtype in FlexibleCategoricalSetIntTypes:
        values = tf.cast(values, tf.int32) + CATEGORICAL_INTEGER_OFFSET
      else:
        raise _bad_dtype(name, feature)
      result[name] = SemanticTensor(semantic=semantic, tensor=values)
      continue

    if semantic == Semantic.HASH:
      if feature.tensor.dtype not in FlexibleHashTypes:
        raise _bad_dtype(name, feature)
      _unroll_and_normalize(
          tf.cast(feature.tensor, tf.string), semantic, name, "", result)
      continue

    if semantic == Semantic.BOOLEAN:
      if feature.tensor.dtype not in FlexibleBooleanTypes:
        raise _bad_dtype(name, feature)
      _unroll_and_normalize(
          tf.cast(feature.tensor, tf.float32), semantic, name, math.nan,
          result)
      continue

    raise ValueError("Non supported semantic {} of feature {}".format(
        semantic, name))

  return result