in learn_bpe.py [0:0]
def pop_max(self):
    """
    Pop the top entry of the stats heap, merge its pair and refresh counts.

    The popped entry is a ``(freq, pair)`` tuple where ``pair`` holds two
    symbols. Every vocabulary word indexed under that pair is rewritten so
    that adjacent occurrences of the two symbols become one concatenated
    symbol. The pair-frequency changes reported by ``_update_stats`` are
    then written back into the heap.

    Returns:
        A 2-tuple ``(max_elem, score)``: the popped ``(freq, pair)`` entry
        and the ``score`` attribute of the heap item that carried it.
    """
    top = self.stats_heap.pop_max_cached_value()
    best = top.value
    self.stats_heap.increase_timestep()
    count, pair = best
    left, right = pair
    merged = left + right

    if self.probabilistic:
        # Each merge turns two symbols into one: `count` fewer running
        # symbols, `count` fewer of each part, `count` more of the result.
        self.n_running_symbols -= count
        self.produced_count[left] -= count
        self.produced_count[right] -= count
        self.produced_count[merged] += count

    # The word rewriting below goes through a space-joined string and a
    # regex, following the original implementation. It could likely be
    # done directly on the symbol tuples, without `re`.
    #
    # Backslashes are doubled so that `sub` inserts the merged symbol
    # literally instead of interpreting escape sequences in it.
    replacement = merged.replace('\\', '\\\\')
    # The lookarounds require whitespace (or the string edge) on both
    # sides, so only whole symbols are matched.
    pair_re = re.compile(
        r'(?<!\S)' + re.escape(left + ' ' + right) + r'(?!\S)'
    )

    deltas = {}
    for idx in self.vocab_entries_for_pair[pair]:
        # Rewrite the word with the pair merged.
        before = self.vocab.get_word(idx)
        joined = pair_re.sub(replacement, " ".join(before))
        after = tuple(joined.split(" "))
        self.vocab.substitute(before, after)
        counts = self.vocab.get_counts_from_index(idx)
        # `deltas` is handed to _update_stats, which fills it in.
        self._update_stats(pair, before, after, counts, idx, deltas)

    # Phase 1: compute the new frequency of every touched pair, and
    # invalidate the heap entries that already exist for them.
    new_freqs = {}
    for changed, delta in deltas.items():
        entry = self.stats_heap.get(changed)
        if entry:
            new_freqs[changed] = entry[0] + delta
            self.stats_heap.invalidate_key(changed)
        else:
            new_freqs[changed] = delta

    # Phase 2: re-insert only the pairs that still occur.
    for changed, total in new_freqs.items():
        if total <= 0:
            continue
        self.stats_heap.insert((total, changed))

    # The merged pair no longer needs its word index.
    del self.vocab_entries_for_pair[pair]
    return best, top.score