in tensorflow_text/tools/wordpiece_vocab/wordpiece_tokenizer_learner_lib.py [0:0]
def learn_with_thresh(word_counts, thresh, params):
  """Wordpiece learning algorithm to produce a vocab given frequency threshold.

  Args:
    word_counts: list of (string, int) tuples
    thresh: int, frequency threshold for a token to be included in the vocab
    params: Params namedtuple, parameters for learning

  Returns:
    list of strings, vocabulary generated for the given thresh
  """

  # Set of single-character tokens.
  char_tokens = extract_char_tokens(word_counts)
  curr_tokens = ensure_all_tokens_exist(char_tokens, {},
                                        params.include_joiner_token,
                                        params.joiner)
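
  # Run num_iterations refinement passes; after the first, each pass re-splits
  # the words with the vocabulary selected in the previous pass and recounts.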
  for iteration in range(params.num_iterations):
    subtokens = [dict() for _ in range(params.max_token_length + 1)]
    # Populate array with counts of each subtoken.
    for word, count in word_counts:
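      # First pass: allow a split at every character boundary. Later passes:
      # split each word using the tokens learned in the previous pass.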
      if iteration == 0:
        split_indices = range(1, len(word) + 1)
      else:
        split_indices = get_split_indices(word, curr_tokens,
                                          params.include_joiner_token,
                                          params.joiner)
        if not split_indices:
          continue
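
      # Count every candidate subtoken that begins at a split boundary; pieces
      # that do not start the word get the joiner prefix (when
      # include_joiner_token is set).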
      start = 0
      for index in split_indices:
        for end in range(start + 1, len(word) + 1):
          subtoken = word[start:end]
          length = len(subtoken)
          if params.include_joiner_token and start > 0:
            subtoken = params.joiner + subtoken
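          # `length` was taken before the joiner was prepended, so a joined key
          # may be longer than the length bucket it is stored under.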
          if subtoken in subtokens[length]:
            # Subtoken exists, increment count.
            subtokens[length][subtoken] += count
          else:
            # New subtoken, add to dict.
            subtokens[length][subtoken] = count
        start = index

    next_tokens = {}
    # Get all tokens that have a count above the threshold.
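    # Process lengths from longest to shortest so that a selected token's count
    # is removed from its prefixes before those prefixes are considered.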
    for length in range(params.max_token_length, 0, -1):
      for token, count in subtokens[length].items():
        if count >= thresh:
          next_tokens[token] = count
          # Decrement the count of all prefixes.
          if len(token) > length:  # This token includes the joiner.
            joiner_len = len(params.joiner)
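            # The stored key carries the joiner, so indices into `token` are
            # offset by joiner_len relative to the un-joined length buckets.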
            for i in range(1 + joiner_len, length + joiner_len):
              prefix = token[0:i]
              if prefix in subtokens[i - joiner_len]:
                subtokens[i - joiner_len][prefix] -= count
          else:
            for i in range(1, length):
              prefix = token[0:i]
              if prefix in subtokens[i]:
                subtokens[i][prefix] -= count

    # Add back single-character tokens.
    curr_tokens = ensure_all_tokens_exist(char_tokens, next_tokens,
                                          params.include_joiner_token,
                                          params.joiner)

  vocab_words = generate_final_vocabulary(params.reserved_tokens, char_tokens,
                                          curr_tokens)

  return vocab_words
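

# Illustrative usage sketch (not part of the library source): `DemoParams` is a
# stand-in namedtuple carrying only the fields learn_with_thresh reads here
# (num_iterations, max_token_length, include_joiner_token, joiner,
# reserved_tokens); the module's real Params namedtuple may define more fields,
# and the counts and threshold below are made up for the example.
import collections

DemoParams = collections.namedtuple(
    'DemoParams',
    ['num_iterations', 'max_token_length', 'include_joiner_token', 'joiner',
     'reserved_tokens'])

demo_counts = [('hello', 10), ('help', 7), ('held', 3)]
demo_params = DemoParams(
    num_iterations=4,
    max_token_length=10,
    include_joiner_token=True,
    joiner='##',
    reserved_tokens=['[UNK]'])

# Subtokens whose (prefix-discounted) count falls below 3 are excluded, except
# single characters, which are always added back to the vocabulary.
demo_vocab = learn_with_thresh(demo_counts, 3, demo_params)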