Sources/Tokenizers/Trie.swift (76 lines of code) (raw):

//
//  Trie.swift
//
//
//  Created by Pedro Cuenca on 20240112.
//  Copyright © 2024 Hugging Face. All rights reserved.
//

import Foundation

/// A prefix tree keyed by elements of type `T`.
///
/// `Trie` is a struct, but its nodes are instances of the class `TrieNode`. Copies of a `Trie`
/// therefore share the same node graph, so an insertion made through one copy is visible through
/// all of them. This is also why `insert(_:)` does not need to be `mutating`.
public struct Trie<T: Hashable> {
    public typealias Node = TrieNode<T>

    var root: Node

    /// Creates a trie rooted at `root`, or at a fresh empty node when `root` is `nil`.
    public init(root: Node? = nil) {
        self.root = root ?? Node()
    }
}

public extension Trie {
    /// Inserts `element` into the trie.
    ///
    /// Missing nodes along the element's path are created. The final node is then marked as the
    /// end of an inserted sequence.
    func insert(_ element: any Sequence<T>) {
        var node = root
        for item in element {
            if let child = node.children[item] {
                node = child
            } else {
                let child = Node()
                node.children[item] = child
                node = child
            }
        }
        node.isLeaf = true
    }

    /// Inserts every sequence contained in `container`.
    func append(contentsOf container: any Sequence<any Sequence<T>>) {
        for t in container {
            insert(t)
        }
    }

    /// Find all inserted sequences that are a prefix of the input sequence (usually a text).
    ///
    /// The search walks `text` from the root one element at a time. It stops at the first element
    /// that has no matching child, or when `text` runs out. An empty sequence stored at the root is
    /// never reported, because `isLeaf` is only checked after descending to a child.
    /// - Parameter text: The sequence to match against the trie.
    /// - Returns: An array of the matching prefixes, shortest first.
    func commonPrefixSearch(_ text: any Sequence<T>) -> [[T]] {
        var node = root
        var seqs: [[T]] = []
        var seq: [T] = []
        for item in text {
            seq.append(item)
            guard let child = node.children[item] else { return seqs }
            node = child
            if node.isLeaf {
                seqs.append(seq)
            }
        }
        return seqs
    }

    /// Find all inserted sequences that are a prefix of the input sequence (usually a text).
    ///
    /// This is the lazy counterpart of `commonPrefixSearch(_:)`. It yields the same prefixes,
    /// shortest first, one per call to `next()`.
    /// - Returns: An iterator that is also a single-pass sequence.
    func commonPrefixSearchIterator(_ text: any Sequence<T>) -> LeavesWithCommonPrefixIterator<T> {
        LeavesWithCommonPrefixIterator(node: root, text: text)
    }
}

public extension Trie {
    /// Returns the node reached by following `element` from the root, or `nil` if that path does
    /// not exist.
    ///
    /// The returned node is not necessarily the end of an inserted sequence; check its `isLeaf`.
    /// Only used for testing, could migrate to collection.
    func get(_ element: any Sequence<T>) -> Node? {
        var node = root
        for item in element {
            guard let child = node.children[item] else { return nil }
            node = child
        }
        return node
    }
}

// TODO: maybe store the scores here if it's helpful?
/// A single node of a `Trie`.
public class TrieNode<T: Hashable> {
    /// `true` when an inserted sequence ends at this node.
    ///
    /// Despite the name, such a node may still have children. For example, after inserting both
    /// "car" and "carp", the node for "car" is marked but has a child.
    var isLeaf: Bool = false
    /// Outgoing edges, keyed by the next element of a sequence.
    var children: [T: TrieNode] = [:]
}

/// Iterator, and single-pass sequence, over the inserted sequences that are prefixes of `text`,
/// shortest first.
///
/// Created by `Trie.commonPrefixSearchIterator(_:)`.
public struct LeavesWithCommonPrefixIterator<T: Hashable>: Sequence, IteratorProtocol {
    /// The trie node reached by the elements consumed so far.
    var node: TrieNode<T>
    var text: any Sequence<T>
    /// The elements of `text` consumed so far.
    var seq: [T] = []
    lazy var iterator = text.makeIterator() as any IteratorProtocol<T>
    /// Set once `next()` has returned `nil`, so that every later call also returns `nil`.
    var isExhausted = false

    public mutating func next() -> [T]? {
        // IteratorProtocol requires that, once `nil` has been returned, all later calls return
        // `nil`. Without this flag, a call made after a mismatch would resume walking `text` from
        // the mismatched node. It could then yield sequences that were never inserted.
        guard !isExhausted else { return nil }
        while let item = iterator.next() {
            seq.append(item)
            // No edge for this element: no longer prefix can match, so stop for good.
            guard let child = node.children[item] else { break }
            node = child
            if node.isLeaf {
                return seq
            }
        }
        isExhausted = true
        return nil
    }
}