Sources/Tokenizers/Normalizer.swift (268 lines of code) (raw):
//
// Normalizer.swift
//
//
// Created by Pedro Cuenca on 17/7/23.
//
import Foundation
import Hub
/// A text normalization step that can be built from a tokenizer configuration.
public protocol Normalizer {
    /// Creates the normalizer from its configuration entry.
    init(config: Config)

    /// Returns the normalized form of `text`.
    func normalize(text: String) -> String

    /// Lets an instance be invoked like a function: `normalizer(text: "…")`.
    func callAsFunction(text: String) -> String
}
extension Normalizer {
    /// Default implementation: calling a normalizer is the same as calling `normalize(text:)`.
    ///
    /// Declared `public` so that conforming types outside this module inherit it as well;
    /// the protocol itself is public and lists `callAsFunction(text:)` as a requirement.
    public func callAsFunction(text: String) -> String {
        normalize(text: text)
    }
}
/// Normalizer kinds recognised by `NormalizerFactory`.
///
/// Each raw value is the string matched against the `type` field of a normalizer configuration.
enum NormalizerType: String {
    // Composition and simple text edits.
    case Sequence, Prepend, Replace, Lowercase

    // Unicode normalization forms.
    case NFD, NFC, NFKD, NFKC

    // Both names are handled by the same normalizer in the factory.
    case Bert, BertNormalizer

    case Precompiled

    case StripAccents, Strip

    // Matches an empty type name.
    case Unknown = ""
}
/// Builds `Normalizer` instances from configuration entries.
struct NormalizerFactory {
    /// Creates the normalizer described by `config`.
    ///
    /// - Returns: `nil` when `config` is `nil` or its `type` is not a string.
    /// - Note: Traps with `fatalError` when the type name is not a supported normalizer.
    static func fromConfig(config: Config?) -> Normalizer? {
        guard let config, let typeName = config.type.string() else { return nil }

        // Names that do not map to any known case are unsupported.
        guard let kind = NormalizerType(rawValue: typeName) else {
            fatalError("Unsupported Normalizer type: \(typeName)")
        }

        switch kind {
        case .Sequence:
            return NormalizerSequence(config: config)
        case .Prepend:
            return PrependNormalizer(config: config)
        case .Replace:
            return ReplaceNormalizer(config: config)
        case .Lowercase:
            return LowercaseNormalizer(config: config)
        case .NFD:
            return NFDNormalizer(config: config)
        case .NFC:
            return NFCNormalizer(config: config)
        case .NFKD:
            return NFKDNormalizer(config: config)
        case .NFKC:
            return NFKCNormalizer(config: config)
        case .Bert, .BertNormalizer:
            return BertNormalizer(config: config)
        case .Precompiled:
            return PrecompiledNormalizer(config: config)
        case .StripAccents:
            return StripAccentsNormalizer(config: config)
        case .Strip:
            return StripNormalizer(config: config)
        case .Unknown:
            // An empty type name is treated like any other unsupported name.
            fatalError("Unsupported Normalizer type: \(typeName)")
        }
    }
}
/// Applies a list of normalizers one after another.
class NormalizerSequence: Normalizer {
    let normalizers: [Normalizer]

    /// Builds the child normalizers from the `normalizers` array of `config`.
    ///
    /// Entries for which the factory returns `nil` are skipped.
    /// Traps with `fatalError` when `config` has no `normalizers` array.
    public required init(config: Config) {
        guard let entries = config.normalizers.array() else {
            fatalError("No normalizers in Sequence")
        }

        var built: [Normalizer] = []
        for entry in entries {
            if let normalizer = NormalizerFactory.fromConfig(config: entry) {
                built.append(normalizer)
            }
        }
        normalizers = built
    }

    /// Feeds `text` through every normalizer in order; each step receives the previous step's output.
    public func normalize(text: String) -> String {
        var current = text
        for step in normalizers {
            current = step(text: current)
        }
        return current
    }
}
/// Puts a fixed string in front of the input.
class PrependNormalizer: Normalizer {
    let prepend: String

    /// Reads the `prepend` string from `config`, falling back to an empty string.
    public required init(config: Config) {
        prepend = config.prepend.string(or: "")
    }

    /// Returns `text` with the configured prefix in front of it.
    public func normalize(text: String) -> String {
        return "\(prepend)\(text)"
    }
}
/// Replaces occurrences of a configured pattern in the input.
class ReplaceNormalizer: Normalizer {
    let pattern: StringReplacePattern?

    /// Builds the replacement pattern from `config`; it stays `nil` when none can be built.
    public required init(config: Config) {
        pattern = StringReplacePattern.from(config: config)
    }

    /// Applies the pattern to `text`, or returns `text` untouched when there is no pattern.
    public func normalize(text: String) -> String {
        return pattern?.replace(text) ?? text
    }
}
/// Converts the input to lowercase.
class LowercaseNormalizer: Normalizer {
    /// The configuration is not consulted.
    public required init(config: Config) { }

    /// Returns the lowercased version of `text`.
    public func normalize(text: String) -> String {
        return text.lowercased()
    }
}
/// Unicode normalization form D (canonical decomposition).
class NFDNormalizer: Normalizer {
    /// The configuration is not consulted.
    public required init(config: Config) { }

    /// Returns `text` decomposed with the canonical mapping.
    public func normalize(text: String) -> String {
        return text.decomposedStringWithCanonicalMapping
    }
}
/// Unicode normalization form C (canonical decomposition followed by composition).
class NFCNormalizer: Normalizer {
    /// The configuration is not consulted.
    public required init(config: Config) { }

    /// Returns `text` precomposed with the canonical mapping.
    public func normalize(text: String) -> String {
        return text.precomposedStringWithCanonicalMapping
    }
}
/// Unicode normalization form KD (compatibility decomposition).
class NFKDNormalizer: Normalizer {
    /// The configuration is not consulted.
    required init(config: Config) { }

    /// Returns `text` decomposed with the compatibility mapping.
    func normalize(text: String) -> String {
        return text.decomposedStringWithCompatibilityMapping
    }
}
/// Unicode normalization form KC (compatibility decomposition followed by composition).
class NFKCNormalizer: Normalizer {
    /// The configuration is not consulted.
    required init(config: Config) { }

    /// Returns `text` precomposed with the compatibility mapping.
    func normalize(text: String) -> String {
        return text.precomposedStringWithCompatibilityMapping
    }
}
/// BERT-style normalizer.
///
/// Depending on its flags it cleans the text, pads Chinese characters with spaces,
/// strips accents and lowercases, in that order.
class BertNormalizer: Normalizer {
    let shouldCleanText: Bool
    let shouldHandleChineseChars: Bool
    let shouldStripAccents: Bool
    let shouldLowercase: Bool

    /// Reads the flags from `config`.
    ///
    /// `cleanText`, `handleChineseChars` and `lowercase` default to `true`;
    /// `stripAccents` defaults to the value used for `lowercase`.
    required init(config: Config) {
        shouldCleanText = config.cleanText.boolean(or: true)
        shouldHandleChineseChars = config.handleChineseChars.boolean(or: true)
        shouldLowercase = config.lowercase.boolean(or: true)
        shouldStripAccents = config.stripAccents.boolean(or: shouldLowercase)
    }

    /// Runs the enabled steps over `text` and returns the result.
    func normalize(text: String) -> String {
        var output = text
        if shouldCleanText {
            output = cleanText(text: output)
        }
        if shouldHandleChineseChars {
            output = handleChineseChars(text: output)
        }
        if shouldStripAccents {
            output = stripAccents(text: output)
        }
        if shouldLowercase {
            output = output.lowercased()
        }
        return output
    }

    /// Removes NUL, U+FFFD and control characters, and turns `\t`, `\n`, `\r` into a space.
    ///
    /// Each character is classified by its first Unicode scalar.
    private func cleanText(text: String) -> String {
        text.map { character -> String in
            guard let scalar = character.unicodeScalars.first else {
                return String(character)
            }

            // Unwanted characters are dropped. Previously this branch returned the
            // character itself, so nothing was ever removed.
            if scalar.value == 0x0 || scalar.value == 0xFFFD || isControl(scalar) {
                return ""
            }

            // Replace whitespace: \t, \n, \r
            if scalar.value == 0x009 || scalar.value == 0x00A || scalar.value == 0x000D {
                return " "
            }
            return String(character)
        }
        .joined()
    }

    /// Whether `c` counts as a control character for clean-up purposes.
    private func isControl(_ c: UnicodeScalar) -> Bool {
        if c.value == 0x009 || c.value == 0x00A || c.value == 0x000D {
            // Except \t, \n, \r that will be spaces.
            false
        } else {
            // https://unicode.org/reports/tr44/#GC_Values_Table
            // Other Cc | Cf | Cs | Co | Cn
            isOther(c.properties.generalCategory)
        }
    }

    /// Whether `c` is one of the "Other" general categories (Cc, Cf, Cs, Co, Cn).
    private func isOther(_ c: Unicode.GeneralCategory) -> Bool {
        c == .control || c == .format || c == .surrogate || c == .privateUse || c == .unassigned
    }

    /// Surrounds every character that `Utils.isChineseChar` accepts with a space on each side.
    private func handleChineseChars(text: String) -> String {
        text.map { c in
            if let scalar = c.unicodeScalars.first, Utils.isChineseChar(scalar) {
                " \(c) "
            } else {
                "\(c)"
            }
        }
        .joined()
    }

    /// Decomposes `text` canonically and removes scalars in U+0300...U+036F.
    private func stripAccents(text: String) -> String {
        // This might be the same as `text.folding(options: .diacriticInsensitive, locale: nil)`
        String(text.decomposedStringWithCanonicalMapping.unicodeScalars.filter { scalar in
            !(scalar.value >= 0x0300 && scalar.value <= 0x036F)
        })
    }
}
/// Simplified stand-in for the precompiled-charsmap normalizer.
class PrecompiledNormalizer: Normalizer {
    // TODO: use `precompiledCharsmap` (base64-encoded string) from the configuration
    required init(config: Config) { }

    /// Drops non-printing control characters, maps separator characters to a space and applies
    /// the compatibility-composed mapping, leaving any fullwidth tilde (U+FF5E) as it is.
    func normalize(text: String) -> String {
        // TODO: This is a simplified implementation.
        // - The following comments also apply here:
        // https://github.com/xenova/transformers.js/blob/main/src/tokenizers.js#L2237-L2247
        // - For a proper implementation, see:
        // https://github.com/huggingface/tokenizers/blob/b58227c7f1ccf8b73ee2268354336da56d91e492/tokenizers/src/normalizers/precompiled.rs#L36
        var cleaned = ""
        var sawFullwidthTilde = false

        for scalar in text.unicodeScalars {
            switch scalar.value {
            case 0x0001...0x0008, 0x000B, 0x000E...0x001F, 0x007F, 0x008F, 0x009F:
                // Non-printing control characters contribute nothing to the output.
                continue
            case 0x0009, 0x000A, 0x000C, 0x000D, 0x1680, 0x200B...0x200F, 0x2028, 0x2029, 0x2581,
                 0xFEFF, 0xFFFD:
                // Separators
                cleaned.append(" ")
            default:
                if scalar.value == 0xFF5E {
                    sawFullwidthTilde = true
                }
                cleaned.append(Character(scalar))
            }
        }

        guard sawFullwidthTilde else {
            return cleaned.precomposedStringWithCompatibilityMapping
        }

        // Map the pieces around each fullwidth tilde separately, then put the tildes back,
        // so the tilde itself is not passed through the mapping.
        let pieces = cleaned.split(by: "\u{FF5E}")
        let mappedPieces = pieces.map { $0.precomposedStringWithCompatibilityMapping }
        return mappedPieces.joined(separator: "\u{FF5E}")
    }
}
/// Removes combining marks (accents) from the input.
///
/// The text is not decomposed first, so precomposed letters such as "é" are only affected
/// when an earlier step (for example an NFD normalizer) has split them into base letter + mark.
class StripAccentsNormalizer: Normalizer {
    /// The configuration is not consulted.
    required init(config: Config) { }

    /// Returns `text` without Unicode scalars whose general category is a mark (Mn, Mc or Me).
    ///
    /// The previous body applied the compatibility-composed mapping, which is what
    /// `NFKCNormalizer` does and which composes accents instead of removing them.
    func normalize(text: String) -> String {
        String(text.unicodeScalars.filter { scalar in
            switch scalar.properties.generalCategory {
            case .nonspacingMark, .spacingMark, .enclosingMark:
                return false
            default:
                return true
            }
        })
    }
}
/// Trims whitespace from the start and/or end of the input.
class StripNormalizer: Normalizer {
    let leftStrip: Bool
    let rightStrip: Bool

    /// Reads `stripLeft` and `stripRight` from `config`; both default to `true`.
    required init(config: Config) {
        leftStrip = config.stripLeft.boolean(or: true)
        rightStrip = config.stripRight.boolean(or: true)
    }

    /// Returns `text` without leading and/or trailing whitespace characters, as configured.
    func normalize(text: String) -> String {
        var remainder = Substring(text)

        if leftStrip {
            remainder = remainder.drop(while: { $0.isWhitespace })
        }

        if rightStrip {
            if let lastKept = remainder.lastIndex(where: { !$0.isWhitespace }) {
                remainder = remainder[...lastKept]
            } else {
                // Nothing but whitespace is left.
                remainder = remainder.prefix(0)
            }
        }

        return String(remainder)
    }
}
/// A search pattern paired with the text that replaces each match.
enum StringReplacePattern {
    /// Matches are found with a regular expression; `replacement` is used as the template.
    case regexp(regexp: NSRegularExpression, replacement: String)
    /// Matches are literal occurrences of `pattern`.
    case string(pattern: String, replacement: String)
}

extension StringReplacePattern {
    /// Returns `text` with every match of the pattern replaced.
    func replace(_ text: String) -> String {
        switch self {
        case .string(let literal, let replacement):
            return text.replacingOccurrences(of: literal, with: replacement)
        case .regexp(let regexp, let replacement):
            // Search the whole string.
            let wholeText = NSRange(text.startIndex..<text.endIndex, in: text)
            return regexp.stringByReplacingMatches(
                in: text,
                options: [],
                range: wholeText,
                withTemplate: replacement
            )
        }
    }
}
extension StringReplacePattern {
    /// Builds a replacement pattern from a configuration entry.
    ///
    /// `content` supplies the replacement text. `pattern.String` is tried first and yields a
    /// literal pattern; otherwise `pattern.Regex` yields a regular-expression pattern.
    ///
    /// - Returns: `nil` when `content` is missing or neither pattern key holds a string.
    /// - Note: Traps with `fatalError` when the regular expression cannot be compiled.
    static func from(config: Config) -> StringReplacePattern? {
        guard let replacement = config.content.string() else { return nil }

        if let literal = config.pattern.String.string() {
            return .string(pattern: literal, replacement: replacement)
        }

        guard let source = config.pattern.Regex.string() else { return nil }

        do {
            let regexp = try NSRegularExpression(pattern: source, options: [])
            return .regexp(regexp: regexp, replacement: replacement)
        } catch {
            fatalError("Cannot build regexp from \(source)")
        }
    }
}