in tokenizers/src/models/unigram/model.rs [240:329]
/// Segments `sentence` into the highest-scoring sequence of vocabulary pieces.
///
/// Port of SentencePiece's optimized Viterbi encoder:
/// https://github.com/google/sentencepiece/blob/d48247191a6d50e469ed1a4a36e877befffd1851/src/unigram_model.cc#L600
///
/// A forward pass fills one lattice cell per byte offset with the best path
/// ending there. A backward pass then follows the recorded start offsets from
/// the end of the sentence. Every returned string is a substring of `sentence`.
///
/// A character that no single-character vocabulary piece covers gets an
/// unknown-piece fallback scored `self.min_score - K_UNK_PENALTY`. When
/// `self.fuse_unk` is set, runs of adjacent unknown pieces are returned as a
/// single string.
///
/// # Errors
///
/// Returns `UnigramError::MissingUnkId` if the unknown-piece fallback would
/// improve a lattice cell while `self.unk_id` is `None`.
///
/// # Panics
///
/// Panics if the trie yields bytes that are not valid UTF-8, or a piece that
/// `self.token_to_ids` / `self.vocab` cannot resolve. This is presumably
/// impossible when they are built from the same vocabulary; confirm at the
/// construction site.
fn encode_optimized(&self, sentence: &str) -> Result<Vec<String>> {
    /// Best path found so far that ends at a particular byte offset.
    #[derive(Debug, Clone, Default)]
    struct BestPath {
        /// Vocab id of the last piece on the path (may be the unk id).
        id: usize,
        /// Total score of the path.
        score: f64,
        /// Byte offset where the last piece starts.
        ///
        /// This is `None` until some path reaches this offset. Following these
        /// links backwards reconstructs the whole path.
        start: Option<usize>,
    }

    impl BestPath {
        /// True when a path with total `score` should replace the stored one.
        fn beaten_by(&self, score: f64) -> bool {
            self.start.is_none() || score > self.score
        }

        /// Records a new best path ending here.
        fn set(&mut self, score: f64, start: usize, id: usize) {
            self.score = score;
            self.start = Some(start);
            self.id = id;
        }
    }

    let len = sentence.len();
    let unk_score = self.min_score - K_UNK_PENALTY;
    let mut lattice = vec![BestPath::default(); len + 1];

    // Forward pass: from every character boundary, extend the best path with
    // each vocabulary piece that prefixes the remaining text.
    for (pos, ch) in sentence.char_indices() {
        let width = ch.len_utf8();
        let base = lattice[pos].score;
        let mut covered_by_vocab = false;

        for tok_bytes in self.trie.common_prefix_search(sentence.bytes().skip(pos)) {
            let piece_len = tok_bytes.len();
            let token: String = String::from_utf8(tok_bytes).unwrap();
            let id = *self.token_to_ids.get(&token).unwrap() as usize;
            let candidate = self.vocab.get(id).unwrap().1 + base;

            let cell = &mut lattice[pos + piece_len];
            if cell.beaten_by(candidate) {
                cell.set(candidate, pos, id);
            }
            covered_by_vocab |= piece_len == width;
        }

        // No vocabulary piece spans exactly this one character, so fall back
        // to the unknown piece to keep every boundary reachable.
        if !covered_by_vocab {
            let candidate = unk_score + base;
            let cell = &mut lattice[pos + width];
            // Only look up the unk id once the fallback actually wins the
            // cell, so a missing unk id is not an error when it is unused.
            if cell.beaten_by(candidate) {
                let unk_id = self.unk_id.ok_or(UnigramError::MissingUnkId)?;
                cell.set(candidate, pos, unk_id);
            }
        }
    }

    // Backward pass: follow start links from the end and collect pieces in
    // reverse order.
    //
    // `unk_run_end` holds the right edge of a pending run of unknown pieces
    // that `fuse_unk` merges into one string. Such a run is contiguous in
    // `sentence`, so it can be sliced out in one go.
    let mut pieces: Vec<String> = Vec::new();
    let mut unk_run_end: Option<usize> = None;
    let mut end = len;
    while end > 0 {
        let cell = &lattice[end];
        // Every character boundary received a piece in the forward pass, so
        // the start link is always set here.
        let start = cell.start.unwrap();
        let fusable_unk = self.fuse_unk && Some(cell.id) == self.unk_id;

        if fusable_unk {
            if unk_run_end.is_none() {
                unk_run_end = Some(end);
            }
        } else {
            if let Some(run_end) = unk_run_end.take() {
                pieces.push(sentence[end..run_end].to_string());
            }
            pieces.push(sentence[start..end].to_string());
        }
        end = start;
    }
    if let Some(run_end) = unk_run_end {
        pieces.push(sentence[..run_end].to_string());
    }

    pieces.reverse();
    Ok(pieces)
}