def learn_with_thresh()

in tensorflow_text/tools/wordpiece_vocab/wordpiece_tokenizer_learner_lib.py


def learn_with_thresh(word_counts, thresh, params):
  """Wordpiece learning algorithm to produce a vocab given frequency threshold.

  Args:
    word_counts: list of (string, int) tuples
    thresh: int, frequency threshold for a token to be included in the vocab
    params: Params namedtuple, parameters for learning

  Returns:
    list of strings, vocabulary generated for the given thresh
  """

  # Set of single-character tokens.
  char_tokens = extract_char_tokens(word_counts)
  curr_tokens = ensure_all_tokens_exist(char_tokens, {},
                                        params.include_joiner_token,
                                        params.joiner)

  for iteration in range(params.num_iterations):
    subtokens = [dict() for _ in range(params.max_token_length + 1)]
    # Populate the per-length subtoken dicts with the count of each candidate.
    for word, count in word_counts:
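      # On the first iteration every position is a split point, so every
      # substring of the word is a candidate. Later iterations only split at
      # boundaries produced by segmenting the word with the current vocab.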
      if iteration == 0:
        split_indices = range(1, len(word) + 1)
      else:
        split_indices = get_split_indices(word, curr_tokens,
                                          params.include_joiner_token,
                                          params.joiner)
        if not split_indices:
          continue

      start = 0
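      # From each split start, count every substring extending to any later
      # position as a candidate subtoken.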
      for index in split_indices:
        for end in range(start + 1, len(word) + 1):
          subtoken = word[start:end]
          length = len(subtoken)
          if params.include_joiner_token and start > 0:
            subtoken = params.joiner + subtoken
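          # Note: length was computed before the joiner was prepended, so
          # joiner-prefixed tokens are binned by their bare length; the
          # selection pass below relies on this when it checks
          # len(token) > length.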
          if subtoken in subtokens[length]:
            # Subtoken exists, increment count.
            subtokens[length][subtoken] += count
          else:
            # New subtoken, add to dict.
            subtokens[length][subtoken] = count
        start = index

    next_tokens = {}
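    # Selection pass: process lengths longest-first so each token is tested
    # against a count already reduced by all longer candidates.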
    # Keep tokens whose adjusted count is at or above the threshold.
    for length in range(params.max_token_length, 0, -1):
      for token, count in subtokens[length].items():
        if count >= thresh:
          next_tokens[token] = count
        # Whether or not the token was kept, decrement the count of all of
        # its proper prefixes so shorter candidates are credited only for
        # occurrences not covered by this longer candidate.
        if len(token) > length:  # This token includes the joiner.
          joiner_len = len(params.joiner)
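          # token[0:i] keeps the joiner, so a prefix of bare length
          # (i - joiner_len) lives in the bucket subtokens[i - joiner_len].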
          for i in range(1 + joiner_len, length + joiner_len):
            prefix = token[0:i]
            if prefix in subtokens[i - joiner_len]:
              subtokens[i - joiner_len][prefix] -= count
        else:
          for i in range(1, length):
            prefix = token[0:i]
            if prefix in subtokens[i]:
              subtokens[i][prefix] -= count

    # Add back single-character tokens so every input word stays
    # representable, even if some characters fell below the threshold.
    curr_tokens = ensure_all_tokens_exist(char_tokens, next_tokens,
                                          params.include_joiner_token,
                                          params.joiner)

  vocab_words = generate_final_vocabulary(params.reserved_tokens, char_tokens,
                                          curr_tokens)

  return vocab_words
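
For reference, a minimal usage sketch follows. DemoParams is a hypothetical
stand-in that lists only the fields learn_with_thresh reads; the real Params
namedtuple defined in this module may carry additional fields used by its
callers.

import collections

# Hypothetical stand-in for the module's Params namedtuple; only the fields
# read by learn_with_thresh are listed (an assumption for illustration).
DemoParams = collections.namedtuple(
    'DemoParams',
    ['num_iterations', 'max_token_length', 'include_joiner_token', 'joiner',
     'reserved_tokens'])

word_counts = [('hello', 10), ('hell', 7), ('yellow', 3)]
params = DemoParams(
    num_iterations=4,
    max_token_length=10,
    include_joiner_token=True,
    joiner='##',  # BERT-style continuation marker.
    reserved_tokens=['<unk>'])

vocab = learn_with_thresh(word_counts, thresh=5, params=params)
# vocab is a list of strings: the reserved tokens plus single characters and
# any multi-character wordpieces whose adjusted counts met the threshold.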