Sources/Tokenizers/PostProcessor.swift (148 lines of code) (raw):
//
// PostProcessor.swift
//
//
// Created by Pedro Cuenca on 17/7/23.
//
import Foundation
import Hub
/// A tokenizer post-processing step: takes the token strings of one sequence
/// (and optionally a second, paired sequence) and produces a single token list.
public protocol PostProcessor {
    /// Creates the processor from its configuration.
    init(config: Config)

    /// Post-processes `tokens` and, when supplied, `tokensPair`.
    ///
    /// - Parameters:
    ///   - tokens: Tokens of the first sequence.
    ///   - tokensPair: Tokens of the optional second sequence.
    ///   - addSpecialTokens: Requests that special tokens be added; how the flag
    ///     is honoured is up to the conforming type.
    /// - Returns: The resulting token list.
    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]

    /// Call-syntax counterpart of `postProcess(tokens:tokensPair:addSpecialTokens:)`,
    /// taking the same arguments and returning the same kind of result.
    func callAsFunction(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool) -> [String]
}
extension PostProcessor {
    /// Default call-syntax implementation.
    ///
    /// Forwards every argument, unchanged, to
    /// `postProcess(tokens:tokensPair:addSpecialTokens:)` and returns its result.
    /// `tokensPair` defaults to `nil` and `addSpecialTokens` defaults to `true`.
    func callAsFunction(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
        let processed = postProcess(
            tokens: tokens,
            tokensPair: tokensPair,
            addSpecialTokens: addSpecialTokens
        )
        return processed
    }
}
/// Post-processor kinds that `PostProcessorFactory` can build.
///
/// The raw value of each case is identical to the case name; the factory
/// matches it against the `type` string read from a configuration.
enum PostProcessorType: String {
    case TemplateProcessing, ByteLevel, RobertaProcessing, BertProcessing, Sequence
}
/// Creates `PostProcessor` instances from their configuration.
struct PostProcessorFactory {
    /// Builds the post-processor named by the configuration's `type` entry.
    ///
    /// - Parameter config: The post-processor configuration, if any.
    /// - Returns: `nil` when `config` is `nil` or its `type` is not a string;
    ///   otherwise the processor matching that type, initialised with `config`.
    /// - Note: Stops with `fatalError` when the type string is not a
    ///   `PostProcessorType` raw value.
    static func fromConfig(config: Config?) -> PostProcessor? {
        guard let config, let typeName = config.type.string() else { return nil }

        guard let kind = PostProcessorType(rawValue: typeName) else {
            fatalError("Unsupported PostProcessor type: \(typeName)")
        }

        // Exhaustive over the known kinds; unknown names were rejected above.
        switch kind {
        case .TemplateProcessing:
            return TemplateProcessing(config: config)
        case .ByteLevel:
            return ByteLevelPostProcessor(config: config)
        case .RobertaProcessing:
            return RobertaProcessing(config: config)
        case .BertProcessing:
            return BertProcessing(config: config)
        case .Sequence:
            return SequenceProcessing(config: config)
        }
    }
}
/// Post-processor that assembles its output by walking a template.
///
/// Two templates are read from the configuration: `single`, used when no second
/// sequence is supplied, and `pair`, used when one is.
class TemplateProcessing: PostProcessor {
    /// Template items applied when `tokensPair` is `nil`.
    let single: [Config]
    /// Template items applied when `tokensPair` is not `nil`.
    let pair: [Config]

    /// Reads the `single` and `pair` templates from `config`.
    ///
    /// Stops with `fatalError` when either entry is not an array.
    public required init(config: Config) {
        guard let single = config.single.array() else { fatalError("Missing `single` processor configuration") }
        guard let pair = config.pair.array() else { fatalError("Missing `pair` processor configuration") }
        self.single = single
        self.pair = pair
    }

    /// Expands the applicable template into a token list.
    ///
    /// Each template item is handled as follows:
    /// - an item with a string `SpecialToken.id` contributes that id, but only
    ///   when `addSpecialTokens` is `true`;
    /// - an item whose `Sequence.id` is `"A"` contributes `tokens`;
    /// - an item whose `Sequence.id` is `"B"` contributes `tokensPair`
    ///   (nothing when no pair was supplied);
    /// - any other item contributes nothing.
    ///
    /// - Parameters:
    ///   - tokens: Tokens of the first sequence.
    ///   - tokensPair: Tokens of the optional second sequence; its presence
    ///     selects the `pair` template.
    ///   - addSpecialTokens: Whether special-token items are emitted.
    /// - Returns: The expanded token list.
    func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
        let template = tokensPair == nil ? single : pair

        var toReturn: [String] = []
        for item in template {
            if let id = item.SpecialToken.id.string() {
                if addSpecialTokens {
                    toReturn.append(id)
                }
                continue
            }

            let sequenceID = item.Sequence.id.string()
            if sequenceID == "A" {
                toReturn += tokens
            } else if sequenceID == "B" {
                // A `single` template that mentions sequence B must not trap
                // when there is no second sequence; it simply adds nothing.
                toReturn += tokensPair ?? []
            }
        }
        return toReturn
    }
}
/// Post-processor that hands back the first sequence's tokens untouched.
class ByteLevelPostProcessor: PostProcessor {
    /// Accepts a configuration for protocol conformance; nothing is read from it.
    public required init(config: Config) {}

    /// Returns `tokens` as-is.
    ///
    /// `tokensPair` and `addSpecialTokens` are not used, so a second sequence
    /// does not appear in the result.
    func postProcess(tokens: [String], tokensPair: [String]? = nil, addSpecialTokens: Bool = true) -> [String] {
        return tokens
    }
}
/// RoBERTa-style post-processor: optionally trims whitespace around each token,
/// then wraps the sequence(s) in `cls` / `sep` tokens.
class RobertaProcessing: PostProcessor {
    /// Separator token; presumably (token id, token text) — only the text (`.1`) is used here.
    private let sep: (UInt, String)
    /// Classification token; presumably (token id, token text) — only the text (`.1`) is used here.
    private let cls: (UInt, String)
    /// Trim all remaining space, or leave one space character if `addPrefixSpace` is `true`.
    private let trimOffset: Bool
    /// Keep one space character on each side. Depends on `trimOffset` being `true`.
    private let addPrefixSpace: Bool

    /// Reads `sep`, `cls` and the trimming flags from `config`.
    ///
    /// Stops with `fatalError` when `sep` or `cls` is missing. Both flags
    /// default to `true` when absent.
    public required init(config: Config) {
        guard let sep = config.sep.token() else { fatalError("Missing `sep` processor configuration") }
        guard let cls = config.cls.token() else { fatalError("Missing `cls` processor configuration") }
        self.sep = sep
        self.cls = cls
        // The Hugging Face tokenizer format names this option `trim_offsets`
        // (plural). Prefer that spelling and fall back to the singular key that
        // was read previously, so existing configurations keep working.
        // NOTE(review): assumes `Config` maps `trimOffsets` to that key — confirm.
        trimOffset = config.trimOffsets.boolean(or: config.trimOffset.boolean(or: true))
        addPrefixSpace = config.addPrefixSpace.boolean(or: true)
    }

    /// Trims the tokens (if enabled) and adds the RoBERTa special tokens.
    ///
    /// - Parameters:
    ///   - tokens: Tokens of the first sequence.
    ///   - tokensPair: Tokens of the optional second sequence.
    ///   - addSpecialTokens: When `false`, no `cls` / `sep` tokens are inserted
    ///     and the (trimmed) sequences are simply concatenated.
    /// - Returns: `cls + tokens + sep`, followed by `sep + tokensPair + sep`
    ///   when a non-empty pair is supplied.
    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
        let first = trimmed(tokens)
        let second = tokensPair.map { trimmed($0) }

        guard addSpecialTokens else { return first + (second ?? []) }

        var outTokens = [cls.1] + first + [sep.1]
        if let second, !second.isEmpty {
            // Yes, it adds another `sep`.
            // https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/roberta/hub_interface.py#L58-L65
            outTokens += [sep.1] + second + [sep.1]
        }
        return outTokens
    }

    /// Applies the configured whitespace trimming to every token.
    private func trimmed(_ tokens: [String]) -> [String] {
        guard trimOffset else { return tokens }
        if addPrefixSpace {
            return tokens.map { trimExtraSpaces(token: $0) }
        }
        return tokens.map { $0.trimmingCharacters(in: .whitespaces) }
    }

    /// Some tokens need one space around them
    /// https://github.com/huggingface/tokenizers/blob/main/tokenizers/src/pre_tokenizers/byte_level.rs#L203-L235
    ///
    /// Removes all but one leading and all but one trailing whitespace
    /// character. `dropFirst` / `dropLast` saturate at the ends of the string,
    /// so a token made only of whitespace yields a short (possibly empty)
    /// result instead of forming an invalid range.
    private func trimExtraSpaces(token: String) -> String {
        let leadingExtra = max(0, token.prefix(while: { $0.isWhitespace }).count - 1)
        let trailingExtra = max(0, token.reversed().prefix(while: { $0.isWhitespace }).count - 1)
        return String(token.dropFirst(leadingExtra).dropLast(trailingExtra))
    }
}
/// BERT-style post-processor that frames the sequence(s) with `cls` / `sep` tokens.
class BertProcessing: PostProcessor {
    /// Separator token; only its text (`.1`) is used when building the output.
    private let sep: (UInt, String)
    /// Classification token; only its text (`.1`) is used when building the output.
    private let cls: (UInt, String)

    /// Reads the `sep` and `cls` tokens from `config`.
    ///
    /// Stops with `fatalError` when either one is missing.
    public required init(config: Config) {
        guard let sepToken = config.sep.token() else { fatalError("Missing `sep` processor configuration") }
        guard let clsToken = config.cls.token() else { fatalError("Missing `cls` processor configuration") }
        sep = sepToken
        cls = clsToken
    }

    /// Builds `cls + tokens + sep`, plus `tokensPair + sep` for a non-empty pair.
    ///
    /// - Parameters:
    ///   - tokens: Tokens of the first sequence.
    ///   - tokensPair: Tokens of the optional second sequence.
    ///   - addSpecialTokens: When `false`, the two sequences are concatenated
    ///     without any special tokens.
    /// - Returns: The combined token list.
    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
        let second = tokensPair ?? []
        if !addSpecialTokens {
            return tokens + second
        }

        var result: [String] = [cls.1]
        result.append(contentsOf: tokens)
        result.append(sep.1)
        if !second.isEmpty {
            result.append(contentsOf: second)
            result.append(sep.1)
        }
        return result
    }
}
/// Post-processor that runs a list of other post-processors one after another.
class SequenceProcessing: PostProcessor {
    /// The processors to apply, in configuration order.
    private let processors: [PostProcessor]

    /// Builds one processor per entry of the configuration's `processors` array.
    ///
    /// Entries for which `PostProcessorFactory.fromConfig(config:)` returns `nil`
    /// are left out. Stops with `fatalError` when `processors` is not an array.
    public required init(config: Config) {
        guard let entries = config.processors.array() else {
            fatalError("Missing `processors` configuration")
        }
        var built: [PostProcessor] = []
        for entry in entries {
            if let processor = PostProcessorFactory.fromConfig(config: entry) {
                built.append(processor)
            }
        }
        processors = built
    }

    /// Feeds the tokens through every processor in turn.
    ///
    /// Only the first processor receives `tokensPair`; every later processor
    /// is given the previous output with no pair. With no processors at all,
    /// `tokens` is returned unchanged.
    ///
    /// - Parameters:
    ///   - tokens: Tokens of the first sequence.
    ///   - tokensPair: Tokens of the optional second sequence.
    ///   - addSpecialTokens: Passed through to every processor.
    /// - Returns: The output of the last processor.
    func postProcess(tokens: [String], tokensPair: [String]?, addSpecialTokens: Bool = true) -> [String] {
        guard let head = processors.first else { return tokens }

        let afterHead = head.postProcess(tokens: tokens, tokensPair: tokensPair, addSpecialTokens: addSpecialTokens)

        // After the first processor, we no longer have a separate pair.
        return processors.dropFirst().reduce(afterHead) { partial, processor in
            processor.postProcess(tokens: partial, tokensPair: nil, addSpecialTokens: addSpecialTokens)
        }
    }
}