std::vector<std::string> GPT2BPEEncoder::BPE_()

in torchtext/csrc/gpt2_bpe_tokenizer.cpp [189:253]


// Applies byte-pair-encoding merges to `token_list` and returns the resulting
// token sequence.
//
// Starting from a copy of the input, the function repeatedly asks
// FindBestPair_ for a bigram and fuses every adjacent occurrence of that
// bigram into a single token, until
//  1) the token list is reduced to a single token
//      OR
//  2) the selected bigram is not present in bpe_merge_ranks_.
//
// Caching: when caching_enabled_ is set, the result is looked up in and
// stored to cache_ under the key _concatenate_strings(token_list). An input
// for which _get_pairs yields nothing is returned unchanged and is NOT
// written to the cache.
//
// @param token_list  tokens to merge; never modified.
// @return            the merged token list (by value).
std::vector<std::string> GPT2BPEEncoder::BPE_(
    const std::vector<std::string> &token_list) {
  auto concatenated = _concatenate_strings(token_list);
  if (caching_enabled_ && cache_.contains(concatenated)) {
    return cache_.at(concatenated);
  }

  std::vector<std::string> tok_list = token_list;
  auto pairs = _get_pairs(tok_list, seperator_);
  if (pairs.empty()) {
    return tok_list;
  }
  while (true) {
    auto bigram = FindBestPair_(pairs);
    if (!bpe_merge_ranks_.contains(bigram)) break;

    // Rebuild the token list with every adjacent (first, second) occurrence
    // fused into one token; all other tokens are carried over unchanged.
    //
    // For example: first="a" second="w" and tok_list =
    // ["a", "w", "some", "a", "w", "e"]
    // Result: new_token_list = ["aw", "some", "aw", "e"]

    auto parts = _split_tokens(bigram, seperator_);
    std::vector<std::string> new_token_list;
    // Merging never lengthens the list, so one allocation is enough.
    new_token_list.reserve(tok_list.size());
    std::size_t i = 0;
    while (i < tok_list.size()) {
      // NOTE(review): _list_str_index is assumed to return the index of the
      // next occurrence of parts.first at or after i, or -1 when there is
      // none -- confirm against its definition.
      auto j = _list_str_index(tok_list, parts.first, i);
      if (j == -1) {
        // No further candidate: carry over the remaining tokens and stop.
        for (std::size_t k = i; k < tok_list.size(); ++k)
          new_token_list.push_back(tok_list[k]);
        break;
      }

      // Carry over everything in front of the candidate position.
      const std::size_t next = static_cast<std::size_t>(j);
      for (std::size_t k = i; k < next; ++k)
        new_token_list.push_back(tok_list[k]);
      i = next;

      if (i + 1 < tok_list.size() && tok_list[i] == parts.first &&
          tok_list[i + 1] == parts.second) {
        new_token_list.push_back(parts.first + parts.second);
        i += 2;
      } else {
        new_token_list.push_back(tok_list[i]);
        i += 1;
      }
    }

    // new_token_list is discarded at the end of this iteration, so exchange
    // buffers in O(1) rather than deep-copying every string.
    tok_list.swap(new_token_list);
    if (tok_list.size() == 1) {
      break;
    }
    pairs = _get_pairs(tok_list, seperator_);
  }

  if (caching_enabled_) cache_.insert(concatenated, tok_list);
  return tok_list;
}