Sources/Tokenizers/TokenLattice.swift (101 lines of code) (raw):
//
// TokenLattice.swift
//
//
// Created by Pedro Cuenca on 20240117.
// Copyright © 2024 Hugging Face. All rights reserved.
//
/// A lattice of candidate tokens over a sentence, used to pick the most likely
/// segmentation with the Viterbi algorithm.
///
/// Offsets and lengths are measured in `Character`s of `sentence`. The lattice always
/// contains two zero-length sentinel nodes: a BOS node ending at offset 0 and an EOS
/// node beginning at offset `count`.
///
/// We could make it generic so TokenLatticeNode stores an opaque type, but it's overkill right now.
/// Based on https://github.com/huggingface/tokenizers/blob/b58227c7f1ccf8b73ee2268354336da56d91e492/tokenizers/src/models/unigram/lattice.rs#L137
/// and https://github.com/xenova/transformers.js/blob/b07336d8f7ff57453cc164cc68aead2a79cbd57e/src/utils/data-structures.js#L269C28-L269C28
public struct TokenLattice {
    let sentence: String
    let bosTokenId: Int
    let eosTokenId: Int
    /// Every node added to the lattice, including the BOS and EOS sentinels.
    var nodes: [TokenLatticeNode] = []
    /// `beginNodes[i]` lists the nodes whose token starts at character offset `i`.
    var beginNodes: [[TokenLatticeNode]]
    /// `endNodes[i]` lists the nodes whose token ends at character offset `i`.
    var endNodes: [[TokenLatticeNode]]
    /// Number of `Character`s in `sentence`.
    ///
    /// Stored once instead of being recomputed: `String.count` walks the whole string,
    /// and `sentence` never changes after initialization.
    let count: Int

    /// Creates a lattice for `sentence` that holds only the BOS and EOS sentinel nodes.
    ///
    /// - Parameters:
    ///   - sentence: The text being tokenized.
    ///   - bosTokenId: Token id assigned to the beginning-of-sentence sentinel.
    ///   - eosTokenId: Token id assigned to the end-of-sentence sentinel.
    init(sentence: String, bosTokenId: Int, eosTokenId: Int) {
        let length = sentence.count
        self.sentence = sentence
        self.bosTokenId = bosTokenId
        self.eosTokenId = eosTokenId
        count = length
        // One slot per character boundary: offsets 0...length inclusive.
        beginNodes = Array(repeating: [], count: length + 1)
        endNodes = Array(repeating: [], count: length + 1)

        let bos = TokenLatticeNode(tokenId: bosTokenId, startOffset: 0, length: 0, score: 0)
        let eos = TokenLatticeNode(tokenId: eosTokenId, startOffset: length, length: 0, score: 0)
        nodes.append(bos)
        nodes.append(eos)
        // EOS is registered only as a node that begins at the end of the sentence, and
        // BOS only as a node that ends at offset 0.
        beginNodes[length].append(eos)
        endNodes[0].append(bos)
    }
}
public extension TokenLattice {
    /// Adds a candidate token to the lattice.
    ///
    /// The new node is recorded under the offset where it begins (`beginNodes`), the
    /// offset where it ends (`endNodes`), and in the flat `nodes` list.
    ///
    /// - Parameters:
    ///   - startOffset: Character offset in the sentence where the token begins.
    ///   - length: Number of characters the token covers.
    ///   - score: Score of the token.
    ///   - tokenId: Id of the token in the tokenizer.
    mutating func insert(startOffset: Int, length: Int, score: Float, tokenId: Int) {
        let candidate = TokenLatticeNode(
            tokenId: tokenId,
            startOffset: startOffset,
            length: length,
            score: score
        )
        let endOffset = startOffset + length

        beginNodes[startOffset].append(candidate)
        endNodes[endOffset].append(candidate)
        nodes.append(candidate)
    }
}
extension TokenLattice {
    /// Runs the Viterbi algorithm and returns the highest-scoring path of tokens.
    ///
    /// Offsets are visited in increasing order. For every node starting at an offset, the
    /// predecessor among the nodes ending at that same offset that maximizes
    /// `backtraceScore + score` is chosen; a `clone()` of it is stored in `prev` and the
    /// resulting sum in `backtraceScore`. The path is then read backwards from the EOS node.
    ///
    /// Although this method is non-mutating, it rewrites `prev` and `backtraceScore` on the
    /// (reference-type) nodes. It's unfortunate that the result can't be lazy or cached, as the
    /// node arrays are not immutable. We could create another type that holds the nodes and
    /// use it as an immutable var in TokenLattice.
    ///
    /// - Returns: The nodes on the best path in sentence order, excluding the BOS and EOS
    ///   sentinels. Returns an empty array when no complete path exists: either some offset
    ///   has no node starting there, or some node has no node ending where it starts. The
    ///   reference implementations bail out the same way.
    func viterbi() -> [TokenLatticeNode] {
        for offset in 0...count {
            // A gap in the lattice means no path can cross this offset.
            guard !beginNodes[offset].isEmpty else { return [] }

            for rnode in beginNodes[offset] {
                rnode.prev = nil
                var bestScore: Float = 0
                var bestNode: TokenLatticeNode?
                for lnode in endNodes[offset] {
                    let score = lnode.backtraceScore + rnode.score
                    if bestNode == nil || score > bestScore {
                        bestNode = lnode.clone()
                        bestScore = score
                    }
                }
                // An unreachable node would otherwise keep `prev == nil` and could still be
                // picked as a predecessor later; the backtrace below would then mistake it for
                // BOS and silently drop every token before it.
                guard let best = bestNode else { return [] }
                rnode.prev = best
                rnode.backtraceScore = bestScore
            }
        }

        let root = beginNodes[count][0]
        guard let prev = root.prev else { return [] }

        // Walk back from the node preceding EOS; the walk stops at the first node without a
        // `prev`, which is the BOS sentinel, so BOS is not included.
        // TODO: the reference implementations have a few more clones here: verify
        var result: [TokenLatticeNode] = []
        var node = prev
        while let previous = node.prev {
            result.append(node.clone())
            node = previous
        }
        return result.reversed()
    }

    /// Returns the substring of the sentence to be tokenized associated to the specified node.
    ///
    /// The indices are found by stepping through the sentence character by character, so the
    /// cost grows with `node.startOffset + node.length`.
    ///
    /// - Parameters:
    ///   - node: The node defining the token to be extracted.
    ///
    /// - Returns: A **Substring** – i.e., a reference to the original positions, not a copy of the characters.
    func piece(_ node: TokenLatticeNode) -> any StringProtocol {
        let start = sentence.index(sentence.startIndex, offsetBy: node.startOffset)
        let end = sentence.index(start, offsetBy: node.length)
        return sentence[start..<end]
    }
}
public extension TokenLattice {
    /// Text of each token on the best Viterbi path, in sentence order.
    ///
    /// Every access runs `viterbi()` again.
    var tokens: [String] {
        let path = viterbi()
        return path.map { node in String(piece(node)) }
    }

    /// Token id of each node on the best Viterbi path, in sentence order.
    ///
    /// Every access runs `viterbi()` again.
    var tokenIds: [Int] {
        viterbi().map(\.tokenId)
    }
}
/// One candidate token in a `TokenLattice`, covering `length` characters starting at `startOffset`.
///
/// This is a reference type because structs can't contain references to the same type
/// (`prev` points to another node). The Viterbi pass updates `prev` and `backtraceScore` in place.
class TokenLatticeNode {
    let tokenId: Int
    let startOffset: Int
    let length: Int
    let score: Float
    /// Predecessor on the best path found so far, or `nil` if none has been assigned.
    var prev: TokenLatticeNode?
    /// Cumulative score of the best path ending with this node.
    var backtraceScore: Float = 0

    init(tokenId: Int, startOffset: Int, length: Int, score: Float, prev: TokenLatticeNode? = nil, backtraceScore: Float = 0) {
        self.tokenId = tokenId
        self.startOffset = startOffset
        self.length = length
        self.score = score
        self.prev = prev
        self.backtraceScore = backtraceScore
    }
}

extension TokenLatticeNode {
    /// Returns a new node carrying the same field values.
    ///
    /// The copy is shallow: the clone points at the same `prev` node as the original.
    /// We could implement NSCopying, but frankly I don't see the point.
    func clone() -> TokenLatticeNode {
        TokenLatticeNode(tokenId: tokenId, startOffset: startOffset, length: length, score: score, prev: prev, backtraceScore: backtraceScore)
    }
}

extension TokenLatticeNode: CustomStringConvertible {
    /// Debug representation; `prev` is shown only as whether a predecessor is set.
    var description: String {
        "TokenLatticeNode(tokenId: \(tokenId), startOffset: \(startOffset), length: \(length), score: \(score), prev: \(prev != nil), backtraceScore: \(backtraceScore))"
    }
}