# apply_vocabulary() — excerpt from tensorflow_transform/mappers.py

def apply_vocabulary(
    x: common_types.ConsistentTensorType,
    deferred_vocab_filename_tensor: common_types.TemporaryAnalyzerOutputType,
    default_value: Any = -1,
    num_oov_buckets: int = 0,
    lookup_fn: Optional[Callable[[common_types.TensorType, tf.Tensor],
                                 Tuple[tf.Tensor, tf.Tensor]]] = None,
    file_format: common_types.VocabularyFileFormatType = analyzers
    .DEFAULT_VOCABULARY_FILE_FORMAT,
    name: Optional[str] = None) -> common_types.ConsistentTensorType:
  r"""Maps values of `x` to their indices in a deferred vocabulary file.

  Besides performing the lookup, this records schema domain overrides for the
  minimum and maximum possible output values. Both bounds are inclusive and
  depend on the vocabulary size, `num_oov_buckets` and `default_value`.

  Args:
    x: A categorical `Tensor` or `CompositeTensor` of dtype tf.string or
      tf.int[8|16|32|64] to be mapped through the vocabulary. The column names
      are those intended for the transformed tensors.
    deferred_vocab_filename_tensor: The deferred vocab filename tensor as
      returned by `tft.vocabulary`, as long as the frequencies were not stored.
    default_value: Value assigned to out-of-vocabulary tokens when
      'num_oov_buckets' is zero or negative.
    num_oov_buckets:  When greater than zero, out-of-vocabulary tokens are
      hashed into one of this many buckets instead of receiving
      `default_value`.
    lookup_fn: Optional lookup function. If given, it takes a tensor and a
      deferred vocab filename and returns a lookup `op` together with the
      table size; otherwise `apply_vocabulary` builds a StaticHashTable for
      the lookup itself.
    file_format: (Optional) A str. The format of the given vocabulary.
      Accepted formats are: 'tfrecord_gzip', 'text'.
      The default value is 'text'.
    name: (Optional) A name for this operation.

  Returns:
    A `Tensor` or `CompositeTensor` whose string values have been replaced by
    integers. Each distinct vocabulary entry maps to a distinct consecutive
    integer starting at zero; values absent from the vocabulary receive
    default_value.
  """
  # Reject the tfrecord format early if this TF build cannot read it.
  if file_format == 'tfrecord_gzip':
    if not tf_utils.is_vocabulary_tfrecord_supported():
      raise ValueError(
          'Vocabulary file_format "tfrecord_gzip" not yet supported for '
          f'{tf.version.VERSION}.')
  with tf.compat.v1.name_scope(name, 'apply_vocab'):
    # Only string and integer inputs are supported.
    if not (x.dtype == tf.string or x.dtype.is_integer):
      raise ValueError('expected tf.string or tf.int[8|16|32|64] but got %r' %
                       x.dtype)

    if lookup_fn:
      # Caller supplied the lookup; delegate entirely to it.
      result, table_size = tf_utils.lookup_table(
          lookup_fn, deferred_vocab_filename_tensor, x)
    else:
      vocab_ref = deferred_vocab_filename_tensor
      blank_filename = isinstance(vocab_ref, (bytes, str)) and not vocab_ref
      if vocab_ref is None or blank_filename:
        raise ValueError('`deferred_vocab_filename_tensor` must not be empty.')

      def _build_table(asset_filepath):
        """Builds the vocabulary lookup table from the given asset file."""
        if file_format == 'tfrecord_gzip':
          initializer = tf_utils.make_tfrecord_vocabulary_lookup_initializer(
              asset_filepath, x.dtype)
        elif file_format == 'text':
          initializer = tf.lookup.TextFileInitializer(
              asset_filepath,
              key_dtype=x.dtype,
              key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
              value_dtype=tf.int64,
              value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
        else:
          raise ValueError(
              '"{}" is not an accepted file_format. It should be one of: {}'
              .format(file_format, analyzers.ALLOWED_VOCABULARY_FILE_FORMATS))

        # With OOV buckets, unknown tokens hash into a bucket; without them,
        # unknown tokens all map to `default_value`.
        if num_oov_buckets > 0:
          return tf.lookup.StaticVocabularyTable(
              initializer,
              num_oov_buckets=num_oov_buckets,
              lookup_key_dtype=x.dtype)
        return tf.lookup.StaticHashTable(
            initializer, default_value=default_value)

      # Look up on the flat values and re-wrap composites afterwards.
      rewrap = _make_composite_tensor_wrapper_if_composite(x)
      flat_values = _get_values_if_composite(x)
      result, table_size = tf_utils.construct_and_lookup_table(
          _build_table, deferred_vocab_filename_tensor, flat_values)
      result = rewrap(result)

    # Record deferred schema overrides for the output's min/max values; these
    # are only known once the analyzer has produced the vocabulary.
    #
    # `table_size` includes the num oov buckets.  The default value is only
    # used if num_oov_buckets <= 0, in which case it may widen the range.
    domain_min = tf.constant(0, tf.int64)
    domain_max = table_size - 1
    if num_oov_buckets <= 0:
      domain_min = tf.minimum(domain_min, default_value)
      domain_max = tf.maximum(domain_max, default_value)
    schema_inference.set_tensor_schema_override(
        _get_values_if_composite(result), domain_min, domain_max)
    return result