//
// BertTokenizer.swift
// CoreMLBert
//
// Created by Julien Chaumond on 27/06/2019.
// Copyright © 2019 Hugging Face. All rights reserved.
//
import Foundation
import Hub
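/// A WordPiece-based tokenizer compatible with BERT checkpoints.
///
/// Tokenization runs in two stages: `BasicTokenizer` handles whitespace and
/// punctuation splitting (plus optional lowercasing and accent stripping), then
/// `WordpieceTokenizer` greedily splits each basic token into vocabulary subwords.
///
/// Minimal usage sketch (the `vocab` dictionary is assumed to come from a
/// BERT-style `vocab.txt` mapping tokens to ids):
///
///     let tokenizer = BertTokenizer(vocab: vocab, merges: nil)
///     let tokens = tokenizer.tokenize(text: "Hello, world!")
///     let ids = tokenizer.tokenizeToIds(text: "Hello, world!")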
public class BertTokenizer {
private let basicTokenizer: BasicTokenizer
private let wordpieceTokenizer: WordpieceTokenizer
private let maxLen = 512
private let tokenizeChineseChars: Bool
private let vocab: [String: Int]
private let ids_to_tokens: [Int: String]
public var bosToken: String?
public var bosTokenId: Int?
public var eosToken: String?
public var eosTokenId: Int?
public let fuseUnknownTokens: Bool
public init(
vocab: [String: Int],
merges: [String]?,
tokenizeChineseChars: Bool = true,
bosToken: String? = nil,
eosToken: String? = nil,
fuseUnknownTokens: Bool = false,
doLowerCase: Bool = true
) {
self.vocab = vocab
ids_to_tokens = Utils.invert(vocab)
basicTokenizer = BasicTokenizer(doLowerCase: doLowerCase)
wordpieceTokenizer = WordpieceTokenizer(vocab: self.vocab)
self.tokenizeChineseChars = tokenizeChineseChars
self.bosToken = bosToken
bosTokenId = bosToken.flatMap { vocab[$0] }
self.eosToken = eosToken
eosTokenId = eosToken.flatMap { vocab[$0] }
self.fuseUnknownTokens = fuseUnknownTokens
}
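/// Builds a tokenizer from Hub-style configuration objects: the vocabulary and
/// added tokens are read from `tokenizerData`, while casing, Chinese-character
/// handling, and special tokens are read from `tokenizerConfig`.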
public required convenience init(tokenizerConfig: Config, tokenizerData: Config, addedTokens: [String: Int]) throws {
guard let vocab = tokenizerData.model.vocab.dictionary() else {
throw TokenizerError.missingVocab
}
let merges: [String]? = tokenizerData.model.merges.get()
let tokenizeChineseChars = tokenizerConfig.handleChineseChars.boolean(or: true)
let eosToken = tokenizerConfig.eosToken.string()
let bosToken = tokenizerConfig.bosToken.string()
let fuseUnknown = tokenizerConfig.fuseUnk.boolean(or: false)
let doLowerCase = tokenizerConfig.doLowerCase.boolean(or: true)
var vocabulary = vocab.reduce(into: [String: Int]()) { result, element in
if let val = element.value.integer() {
result[element.key.string] = val
}
}
if let pairs = tokenizerData.addedTokens.array()?.reduce(into: [String: Int](), { result, element in
guard let val = element["id"].integer() else { return }
guard let key = element["content"].string() else { return }
result[key] = val
}) {
vocabulary.merge(pairs, uniquingKeysWith: { $1 })
}
vocabulary.merge(addedTokens, uniquingKeysWith: { $1 })
self.init(
vocab: vocabulary, merges: merges, tokenizeChineseChars: tokenizeChineseChars, bosToken: bosToken, eosToken: eosToken,
fuseUnknownTokens: fuseUnknown, doLowerCase: doLowerCase
)
}
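/// Splits `text` into wordpiece tokens: basic (whitespace/punctuation)
/// tokenization followed by greedy WordPiece sub-tokenization of each token.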
public func tokenize(text: String) -> [String] {
let text = tokenizeChineseCharsIfNeed(text)
var tokens: [String] = []
for token in basicTokenizer.tokenize(text: text) {
for subToken in wordpieceTokenizer.tokenize(word: token) {
tokens.append(subToken)
}
}
return tokens
}
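/// Maps tokens to vocabulary ids, dropping tokens that are not in the vocabulary.
/// Throws if the sequence is longer than `maxLen` (512).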
private func convertTokensToIds(tokens: [String]) throws -> [Int] {
if tokens.count > maxLen {
throw TokenizerError.tooLong(
"""
Token indices sequence length is longer than the specified maximum
sequence length for this BERT model (\(tokens.count) > \(maxLen)). Running this
sequence through BERT will result in indexing errors.
"""
)
}
return tokens.compactMap { vocab[$0] }
}
/// Main entry point: tokenizes `text` and converts the wordpieces to vocabulary ids.
/// Note: traps (`try!`) if the token sequence exceeds `maxLen`.
func tokenizeToIds(text: String) -> [Int] {
try! convertTokensToIds(tokens: tokenize(text: text))
}
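/// Looks up a single token's id. Traps if the token is not in the vocabulary.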
func tokenToId(token: String) -> Int {
vocab[token]!
}
/// Un-tokenization: get tokens from tokenIds
func unTokenize(tokens: [Int]) -> [String] {
tokens.compactMap { ids_to_tokens[$0] }
}
/// Un-tokenization: merges "##"-prefixed wordpieces back into their preceding token
/// and joins the result with spaces.
func convertWordpieceToBasicTokenList(_ wordpieceTokenList: [String]) -> String {
var tokenList: [String] = []
var individualToken = ""
for token in wordpieceTokenList {
if token.starts(with: "##") {
individualToken += String(token.dropFirst(2))
} else {
if individualToken.count > 0 {
tokenList.append(individualToken)
}
individualToken = token
}
}
tokenList.append(individualToken)
return tokenList.joined(separator: " ")
}
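/// Surrounds every character that `Utils.isChineseChar` recognizes with spaces so
/// that `BasicTokenizer` treats it as its own token, mirroring BERT's
/// `tokenize_chinese_chars` behavior.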
private func tokenizeChineseCharsIfNeed(_ text: String) -> String {
guard tokenizeChineseChars else {
return text
}
return text.map { c in
if let scalar = c.unicodeScalars.first, Utils.isChineseChar(scalar) {
" \(c) "
} else {
"\(c)"
}
}.joined()
}
}
extension BertTokenizer: PreTrainedTokenizerModel {
public var unknownToken: String? { wordpieceTokenizer.unkToken }
public var unknownTokenId: Int? { vocab[unknownToken!] }
func encode(text: String) -> [Int] { tokenizeToIds(text: text) }
func decode(tokens: [Int]) -> String {
let tokens = unTokenize(tokens: tokens)
return convertWordpieceToBasicTokenList(tokens)
}
public func convertTokenToId(_ token: String) -> Int? {
vocab[token] ?? unknownTokenId
}
public func convertIdToToken(_ id: Int) -> String? {
ids_to_tokens[id]
}
}
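/// First tokenization stage: splits on whitespace and punctuation, optionally
/// lowercasing and stripping accents, while leaving BERT special tokens intact.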
class BasicTokenizer {
let doLowerCase: Bool
init(doLowerCase: Bool = true) {
self.doLowerCase = doLowerCase
}
let neverSplit = [
"[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]",
]
func maybeStripAccents(_ text: String) -> String {
guard doLowerCase else { return text }
return text.folding(options: .diacriticInsensitive, locale: nil)
}
func maybeLowercase(_ text: String) -> String {
guard doLowerCase else { return text }
return text.lowercased()
}
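/// Splits `text` on whitespace, then splits each piece further on punctuation,
/// keeping each punctuation character as a separate token. Tokens listed in
/// `neverSplit` (e.g. "[CLS]", "[SEP]") are passed through unchanged.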
func tokenize(text: String) -> [String] {
let splitTokens = maybeStripAccents(text).components(separatedBy: NSCharacterSet.whitespacesAndNewlines)
let tokens = splitTokens.flatMap { (token: String) -> [String] in
if neverSplit.contains(token) {
return [token]
}
var toks: [String] = []
var currentTok = ""
for c in maybeLowercase(token) {
if !c.isExtendedPunctuation {
currentTok += String(c)
} else if currentTok.count > 0 {
toks.append(currentTok)
toks.append(String(c))
currentTok = ""
} else {
toks.append(String(c))
}
}
if currentTok.count > 0 {
toks.append(currentTok)
}
return toks
}
return tokens
}
}
extension Character {
/// https://github.com/huggingface/transformers/blob/8c1b5d37827a6691fef4b2d926f2d04fb6f5a9e3/src/transformers/tokenization_utils.py#L367
var isExtendedPunctuation: Bool {
if isPunctuation { return true }
if let value = unicodeScalars.first?.value {
switch value {
case 33...47: return true
case 58...64: return true
case 91...96: return true
case 123...126: return true
default: return false
}
}
return false
}
}
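/// Second tokenization stage: greedy longest-match-first WordPiece splitting.
/// Continuation pieces are prefixed with "##"; words that cannot be fully
/// decomposed map to the unknown token.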
class WordpieceTokenizer {
let unkToken = "[UNK]"
private let maxInputCharsPerWord = 100
private let vocab: [String: Int]
init(vocab: [String: Int]) {
self.vocab = vocab
}
/// `word`: A single token.
/// Warning: this differs from the `pytorch-transformers` implementation.
/// This should have already been passed through `BasicTokenizer`.
func tokenize(word: String) -> [String] {
if word.count > maxInputCharsPerWord {
return [unkToken]
}
var outputTokens: [String] = []
var isBad = false
var start = 0
var subTokens: [String] = []
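// Greedy longest-match-first: repeatedly take the longest substring (prefixed
// with "##" after the first piece) that exists in the vocabulary.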
while start < word.count {
var end = word.count
var cur_substr: String?
while start < end {
var substr = Utils.substr(word, start..<end)!
if start > 0 {
substr = "##\(substr)"
}
if vocab[substr] != nil {
cur_substr = substr
break
}
end -= 1
}
if cur_substr == nil {
isBad = true
break
}
subTokens.append(cur_substr!)
start = end
}
if isBad {
outputTokens.append(unkToken)
} else {
outputTokens.append(contentsOf: subTokens)
}
return outputTokens
}
}