in torchbenchmark/models/attention_is_all_you_need_pytorch/apply_bpe.py [0:0]
def encode(orig, bpe_codes, bpe_codes_reverse, vocab, separator, version, cache, glossaries_regex=None, dropout=0):
    """Encode word based on list of BPE merge operations, which are applied consecutively.

    Args:
        orig: The word (string) to segment.
        bpe_codes: Mapping from a symbol pair ``(left, right)`` to its rank.
            The pair with the smallest rank present in the word is merged
            first.
        bpe_codes_reverse: Passed through unchanged to
            ``check_vocab_and_split`` when ``vocab`` is truthy.
        vocab: If truthy, the merged segments are post-processed by
            ``check_vocab_and_split``.
        separator: Passed through unchanged to ``check_vocab_and_split``.
        version: ``(0, 1)`` appends ``'</w>'`` as a separate end-of-word
            symbol; ``(0, 2)`` attaches ``'</w>'`` to the last character.
        cache: Mutable mapping from word to its segmentation. It is consulted
            and filled only for deterministic (non-dropout) calls, plus
            glossary matches.
        glossaries_regex: Optional object with a ``match`` method; a word it
            matches is returned as a single unsplit segment.
        dropout: When truthy, every candidate pair is considered in a given
            merge round only if ``random.random() > dropout``.

    Returns:
        A tuple of string segments with the end-of-word marker removed.

    Raises:
        NotImplementedError: If ``version`` is neither ``(0, 1)`` nor
            ``(0, 2)`` (only checked for words longer than one character).
    """
    # Cached results are only valid for deterministic (non-dropout) encoding.
    if not dropout and orig in cache:
        return cache[orig]

    # Glossary entries are never split.
    if glossaries_regex and glossaries_regex.match(orig):
        cache[orig] = (orig,)
        return (orig,)

    # Nothing to merge. Return a tuple like every other path does; the empty
    # case is handled here because version (0, 2) indexes orig[-1] below.
    if not orig:
        return ()
    if len(orig) == 1:
        return (orig,)

    if version == (0, 1):
        word = list(orig) + ['</w>']
    elif version == (0, 2):  # more consistent handling of word-final segments
        word = list(orig[:-1]) + [orig[-1] + '</w>']
    else:
        raise NotImplementedError

    while len(word) > 1:
        # get list of symbol pairs; optionally apply dropout
        # (the dropout test is evaluated before the membership test so the
        # random stream is consumed once per adjacent pair)
        pairs = [
            (bpe_codes[pair], i, pair)
            for (i, pair) in enumerate(zip(word, word[1:]))
            if (not dropout or random.random() > dropout) and pair in bpe_codes
        ]

        if not pairs:
            break

        # get first merge operation in list of BPE codes
        # (tuples compare by rank first, then by position)
        bigram = min(pairs)[2]

        # find start position of all pairs that we want to merge
        positions = [i for (rank, i, pair) in pairs if pair == bigram]

        i = 0
        new_word = []
        bigram = ''.join(bigram)
        for j in positions:
            # merges are invalid if they start before current position.
            # This can happen if there are overlapping pairs: (x x x -> xx x)
            if j < i:
                continue
            new_word.extend(word[i:j])  # all symbols before merged pair
            new_word.append(bigram)  # merged pair
            i = j + 2  # continue after merged pair
        new_word.extend(word[i:])  # add all symbols until end of word
        word = new_word

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word[-1] = word[-1][:-4]

    word = tuple(word)
    if vocab:
        word = check_vocab_and_split(word, bpe_codes_reverse, vocab, separator)

    # A dropout run is stochastic; storing it would make a later
    # deterministic call return a randomly under-merged segmentation.
    if not dropout:
        cache[orig] = word
    return word