// File: Sources/Tokenizers/Decoder.swift

//
//  Decoder.swift
//
//
//  Created by Pedro Cuenca on 17/7/23.
//

import Foundation
import Hub

/// A post-processing step that turns model tokens back into text fragments.
public protocol Decoder {
    /// Transforms `tokens` into decoded fragments (may merge, split, or rewrite them).
    func decode(tokens: [String]) -> [String]
    /// Lets a decoder be invoked as `decoder(tokens: ...)`.
    func callAsFunction(tokens: [String]) -> [String]

    init(config: Config)
}

public extension Decoder {
    /// Default implementation: forwards to `decode(tokens:)`.
    /// Public so conformers outside this module also receive the default.
    func callAsFunction(tokens: [String]) -> [String] {
        decode(tokens: tokens)
    }
}

/// Decoder `type` strings understood by `DecoderFactory`. Raw values equal the case names.
enum DecoderType: String {
    case Sequence
    case WordPiece
    case ByteLevel
    case Replace
    case ByteFallback
    case Fuse
    case Strip
    case Metaspace
    case Unknown = ""
}

struct DecoderFactory {
    /// Builds the decoder described by `config`.
    ///
    /// - Parameters:
    ///   - config: A decoder configuration whose `type` selects the implementation, or `nil`.
    ///   - addedTokens: Tokens that byte-level decoding must pass through verbatim. The set is
    ///     forwarded to `ByteLevelDecoder`, including `ByteLevelDecoder`s nested in a `Sequence`.
    /// - Returns: The decoder, or `nil` when `config` is `nil` or has no `type` string.
    /// - Note: Terminates via `fatalError` for decoder types not listed in `DecoderType`.
    static func fromConfig(config: Config?, addedTokens: Set<String>? = nil) -> Decoder? {
        // TODO: not sure if we need to include `addedTokens` in all the decoder initializers (and the protocol)
        guard let config else { return nil }
        guard let typeName = config.type.string() else { return nil }

        switch DecoderType(rawValue: typeName) ?? .Unknown {
        case .Sequence: return DecoderSequence(config: config, addedTokens: addedTokens)
        case .ByteLevel: return ByteLevelDecoder(config: config, addedTokens: addedTokens)
        case .Replace: return ReplaceDecoder(config: config)
        case .ByteFallback: return ByteFallbackDecoder(config: config)
        case .Fuse: return FuseDecoder(config: config)
        case .Strip: return StripDecoder(config: config)
        case .Metaspace: return MetaspaceDecoder(config: config)
        case .WordPiece: return WordPieceDecoder(config: config)
        case .Unknown: fatalError("Unsupported Decoder type: \(typeName)")
        }
    }
}

/// Re-joins WordPiece sub-tokens.
/// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs
class WordPieceDecoder: Decoder {
    /// Marker on continuation pieces (read from the `prefix` config key).
    let prefix: String
    /// Whether to apply `cleanUpTokenization` to each decoded piece.
    let cleanup: Bool

    public required init(config: Config) {
        guard let prefix = config.prefix.string() else { fatalError("Missing `prefix` configuration for WordPieceDecoder.") }
        self.prefix = prefix
        cleanup = config.cleanup.boolean(or: false)
    }

    /// The first token is kept as is. Each later token either loses `prefix` (a continuation
    /// piece) or gets a leading space (a new word). An empty input yields an empty output.
    func decode(tokens: [String]) -> [String] {
        var decoded: [String] = []
        decoded.reserveCapacity(tokens.count)

        for (index, token) in tokens.enumerated() {
            let piece: String
            if index == 0 {
                piece = token
            } else if token.hasPrefix(prefix) {
                // `dropFirst` instead of `range(of:)!`: `range(of: "")` is nil, which crashed for an empty prefix.
                piece = String(token.dropFirst(prefix.count))
            } else {
                piece = " \(token)"
            }
            decoded.append(cleanup ? cleanUpTokenization(piece) : piece)
        }
        return decoded
    }

    /// Mirrors the reference `cleanup` replacement chain exactly, including " ' " -> "'"
    /// (the former regex produced "' " because its capture group kept the trailing space).
    /// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/wordpiece.rs#L40
    private func cleanUpTokenization(_ token: String) -> String {
        token
            .replacingOccurrences(of: " .", with: ".")
            .replacingOccurrences(of: " ?", with: "?")
            .replacingOccurrences(of: " !", with: "!")
            .replacingOccurrences(of: " ,", with: ",")
            .replacingOccurrences(of: " ' ", with: "'")
            .replacingOccurrences(of: " n't", with: "n't")
            .replacingOccurrences(of: " 'm", with: "'m")
            .replacingOccurrences(of: " do not", with: " don't")
            .replacingOccurrences(of: " 's", with: "'s")
            .replacingOccurrences(of: " 've", with: "'ve")
            .replacingOccurrences(of: " 're", with: "'re")
    }
}

/// Applies the configured `decoders` in order, feeding each one's output to the next.
class DecoderSequence: Decoder {
    let decoders: [Decoder]

    public required convenience init(config: Config) {
        self.init(config: config, addedTokens: nil)
    }

    /// - Parameter addedTokens: Forwarded to each nested decoder so that a `ByteLevel` step
    ///   inside the sequence still passes added tokens through verbatim.
    init(config: Config, addedTokens: Set<String>?) {
        guard let configs = config.decoders.array() else { fatalError("No decoders in Sequence") }
        decoders = configs.compactMap { DecoderFactory.fromConfig(config: $0, addedTokens: addedTokens) }
    }

    func decode(tokens: [String]) -> [String] {
        decoders.reduce(tokens) { current, decoder in
            decoder(tokens: current)
        }
    }
}

/// Maps byte-level (GPT-2 style) characters back to raw bytes and decodes them as UTF-8.
class ByteLevelDecoder: Decoder {
    /// Tokens emitted verbatim; they also split the byte stream into separate fragments.
    let addedTokens: Set<String>

    public required init(config: Config) {
        addedTokens = []
    }

    init(config: Config, addedTokens: Set<String>?) {
        self.addedTokens = addedTokens ?? []
    }

    func decode(tokens: [String]) -> [String] {
        var subTexts: [String] = []
        var pendingBytes: [UInt8] = []

        // Decodes the accumulated bytes as one UTF-8 fragment (invalid sequences become U+FFFD).
        func flushPendingBytes() {
            guard !pendingBytes.isEmpty else { return }
            subTexts.append(String(decoding: pendingBytes, as: UTF8.self))
            pendingBytes.removeAll(keepingCapacity: true)
        }

        for token in tokens {
            if addedTokens.contains(token) {
                flushPendingBytes()
                subTexts.append(token)
            } else {
                pendingBytes.append(contentsOf: bytes(of: token))
            }
        }
        flushPendingBytes()
        return subTexts
    }

    /// Byte values of `token` via `byteDecoder`. Like the reference implementation, a token
    /// containing any unmapped character falls back to its raw UTF-8 bytes instead of crashing.
    private func bytes(of token: String) -> [UInt8] {
        var result: [UInt8] = []
        result.reserveCapacity(token.unicodeScalars.count)
        for scalar in token.unicodeScalars {
            guard let byte = byteDecoder[String(Character(scalar))] else {
                return Array(token.utf8)
            }
            result.append(byte)
        }
        return result
    }
}

/// Applies the configured string replacement to every token; a no-op when no pattern is configured.
class ReplaceDecoder: Decoder {
    let pattern: StringReplacePattern?

    public required init(config: Config) {
        pattern = StringReplacePattern.from(config: config)
    }

    func decode(tokens: [String]) -> [String] {
        guard let pattern else { return tokens }
        return tokens.map { pattern.replace($0) }
    }
}

/// Merges runs of `<0xHH>` byte tokens into UTF-8 text.
/// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/byte_fallback.rs
class ByteFallbackDecoder: Decoder {
    public required init(config: Config) {}

    func decode(tokens: [String]) -> [String] {
        var newTokens: [String] = []
        var pendingBytes: [UInt8] = []

        // Emits the pending run as one string when it is valid UTF-8. Otherwise it emits one
        // U+FFFD per byte, matching the reference implementation.
        func flushPendingBytes() {
            guard !pendingBytes.isEmpty else { return }
            let decoded = String(decoding: pendingBytes, as: UTF8.self)
            // A lossless round trip means no replacement occurred, so the bytes were valid UTF-8.
            if decoded.utf8.elementsEqual(pendingBytes) {
                newTokens.append(decoded)
            } else {
                newTokens.append(contentsOf: Array(repeating: "\u{FFFD}", count: pendingBytes.count))
            }
            pendingBytes.removeAll(keepingCapacity: true)
        }

        for token in tokens {
            if let byte = Self.parseByte(token) {
                pendingBytes.append(byte)
            } else {
                flushPendingBytes()
                newTokens.append(token)
            }
        }
        // Byte tokens at the end of the input were previously dropped.
        flushPendingBytes()
        return newTokens
    }

    /// Parses a `<0xHH>` token. Parsing as `UInt8` rejects values such as `<0x-1>`, which
    /// previously trapped when converted to a code unit.
    private static func parseByte(_ token: String) -> UInt8? {
        guard token.count == 6, token.hasPrefix("<0x"), token.hasSuffix(">") else { return nil }
        let digitsStart = token.index(token.startIndex, offsetBy: 3)
        let digitsEnd = token.index(token.startIndex, offsetBy: 5)
        return UInt8(token[digitsStart..<digitsEnd], radix: 16)
    }
}

/// Concatenates all tokens into a single string.
class FuseDecoder: Decoder {
    public required init(config: Config) {}

    func decode(tokens: [String]) -> [String] {
        [tokens.joined(separator: "")]
    }
}

/// Removes up to `start` leading and `stop` trailing occurrences of `content` from every token.
/// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/decoders/strip.rs
class StripDecoder: Decoder {
    let content: String
    let start: Int
    let stop: Int

    public required init(config: Config) {
        guard let content = config.content.string() else { fatalError("Incorrect StripDecoder configuration: can't parse `content`.") }
        guard let start = config.start.integer() else { fatalError("Incorrect StripDecoder configuration: can't parse `start`.") }
        guard let stop = config.stop.integer() else { fatalError("Incorrect StripDecoder configuration: can't parse `stop`.") }
        self.content = content
        self.start = start
        self.stop = stop
    }

    func decode(tokens: [String]) -> [String] {
        // The configured `content` is honored (it was previously ignored in favor of " ").
        // The reference treats it as a single character, so only its first character is used;
        // an empty `content` strips nothing.
        guard let character = content.first else { return tokens }
        return tokens.map { token in
            token
                .trimmingFromStart(character: character, upto: start)
                .trimmingFromEnd(character: character, upto: stop)
        }
    }
}

/// Turns the metaspace `replacement` character back into spaces.
class MetaspaceDecoder: Decoder {
    /// When true, a leading space on the first token (a prefix-space artifact) is removed.
    let addPrefixSpace: Bool
    let replacement: String

    public required init(config: Config) {
        addPrefixSpace = config.addPrefixSpace.boolean(or: false)
        // Default to "▁" (U+2581), the reference default. The former "_" would have turned
        // literal underscores into spaces whenever the config omitted `replacement`.
        replacement = config.replacement.string(or: "\u{2581}")
    }

    func decode(tokens: [String]) -> [String] {
        var replaced = tokens.map { $0.replacingOccurrences(of: replacement, with: " ") }
        // NOTE(review): the reference drops every replacement char in the first token; this
        // removes a single leading space only. Confirm whether parity is needed.
        if addPrefixSpace, replaced.first?.hasPrefix(" ") == true {
            replaced[0].removeFirst()
        }
        return replaced
    }
}

public extension String {
    /// Returns a copy with at most `upto` leading occurrences of `character` removed.
    /// A non-positive `upto` returns the string unchanged.
    func trimmingFromStart(character: Character = " ", upto: Int) -> String {
        var cursor = startIndex
        var trimmed = 0
        while trimmed < upto, cursor != endIndex, self[cursor] == character {
            formIndex(after: &cursor)
            trimmed += 1
        }
        return String(self[cursor...])
    }

    /// Returns a copy with at most `upto` trailing occurrences of `character` removed.
    /// A non-positive `upto` returns the string unchanged.
    func trimmingFromEnd(character: Character = " ", upto: Int) -> String {
        var cursor = endIndex
        var trimmed = 0
        while trimmed < upto, cursor != startIndex, self[index(before: cursor)] == character {
            formIndex(before: &cursor)
            trimmed += 1
        }
        return String(self[..<cursor])
    }
}