in tokenizers/src/models/unigram/trainer.rs [277:432]
fn prune_sentence_pieces(
&self,
model: &Unigram,
pieces: &[SentencePiece],
sentences: &[Sentence],
) -> Vec<SentencePiece> {
let mut always_keep = vec![true; pieces.len()];
let mut alternatives: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
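// The two sentinel ids sit just past the valid piece ids: the lattice uses
// them only to label its BOS/EOS nodes, never to index into `pieces`.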
let bos_id = pieces.len() + 1;
let eos_id = pieces.len() + 2;
// First, segment each current sentencepiece to determine how it would be
// re-segmented if it were removed from the vocabulary.
// To do so, we take the second-best segmentation of sentencepiece[i]:
// alternatives[i] stores that second-best sequence of sentencepieces.
for (id, (token, _score)) in pieces.iter().enumerate() {
// The unk piece (id 0) is handled separately: it is always pushed into
// new_pieces below, so skip it here.
if id == 0 {
always_keep[id] = false;
continue;
}
let mut lattice = Lattice::from(token, bos_id, eos_id);
model.populate_nodes(&mut lattice);
let nbests = lattice.nbest(2);
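// `nbest(2)` returns up to two segmentations of this piece's surface form,
// best first. Only one segmentation means the piece is irreplaceable; a best
// segmentation of two or more sub-pieces means it can be dropped outright;
// if the piece itself is the best segmentation, the second-best path lists
// the pieces its frequency would be redistributed to on removal.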
if nbests.len() == 1 {
always_keep[id] = true;
} else if nbests[0].len() >= 2 {
always_keep[id] = false;
} else if nbests[0].len() == 1 {
always_keep[id] = true;
for node in &nbests[1] {
let alt_id = node.borrow().id;
alternatives[id].push(alt_id);
}
}
}
// Second, segment all sentences to compute the likelihood
// under the unigram language model. inverted[i] stores
// the indices of the sentences in which sentencepieces[i] appears.
let chunk_size = std::cmp::max(sentences.len() / current_num_threads(), 1);
let indexed_sentences: Vec<(usize, &Sentence)> = sentences.iter().enumerate().collect();
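// Fold over the sentences chunk by chunk (in parallel when the crate's
// parallelism is enabled): each chunk accumulates the total sentence count,
// the per-piece Viterbi frequencies, and an inverted index from piece id to
// the sentences that use it.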
let collected: (f64, Vec<f64>, Vec<Vec<usize>>) = indexed_sentences
.maybe_par_chunks(chunk_size)
.map(|enumerated_sentence_count_chunk| {
let mut vsum = 0.0;
let mut freq: Vec<f64> = vec![0.0; pieces.len()];
let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
for (i, (sentence, count)) in enumerated_sentence_count_chunk {
let mut lattice = Lattice::from(sentence, bos_id, eos_id);
model.populate_nodes(&mut lattice);
vsum += *count as f64;
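// Every node on the Viterbi (single best) path is one piece occurrence in
// this sentence, weighted by the sentence count.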
for node_ref in lattice.viterbi() {
let id = node_ref.borrow().id;
freq[id] += *count as f64;
inverted[id].push(*i);
}
}
(vsum, freq, inverted)
})
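// Merge the per-chunk partials: counts and frequencies add up, the
// per-piece sentence-index lists are concatenated.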
.reduce(
|| (0.0, vec![0.0; pieces.len()], vec![Vec::new(); pieces.len()]),
|(vsum, freq, inverted), (lvsum, lfreq, linverted)| {
(
vsum + lvsum,
freq.iter()
.zip(lfreq)
.map(|(global_el, local_el)| global_el + local_el)
.collect(),
inverted
.iter()
.zip(linverted)
.map(|(global_el, local_el)| [&global_el[..], &local_el[..]].concat())
.collect(),
)
},
);
let (vsum, freq, inverted) = collected;
let sum: f64 = freq.iter().sum();
let logsum = sum.ln();
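// `sum` is the total Viterbi frequency mass, so ln(freq[id]) - logsum is
// the log-probability of piece `id` under the current unigram LM.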
let mut candidates: Vec<(usize, f64)> = vec![];
let mut new_pieces: Vec<SentencePiece> = Vec::with_capacity(self.vocab_size as usize);
new_pieces.push(pieces[0].clone());
// Finally, compute how much the LM likelihood is reduced if
// sentencepiece[i] is removed from the vocabulary.
// Since the exact computation of the loss is difficult, we approximate
// it by assuming that every sentencepiece[i] in the sentences is
// replaced with alternatives[i] when sentencepiece[i] is removed.
for (id, (token, score)) in pieces.iter().enumerate() {
if id == 0 {
continue;
}
if freq[id] == 0.0 && !always_keep[id] {
// not found in Viterbi path. Can remove this entry safely.
continue;
} else if alternatives[id].is_empty() {
// no alternatives. Keeps this entry.
new_pieces.push((token.to_string(), *score));
} else {
let mut f = 0.0; // the frequency of pieces[id], weighted by sentence counts
for n in &inverted[id] {
let count = sentences[*n].1 as f64;
f += count;
}
// TODO: Temporary hack to avoid NaNs.
if f == 0.0 || f.is_nan() {
// new_pieces.push((token.to_string(), *score));
continue;
}
f /= vsum; // normalizes by all sentence frequency.
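// `f` is now the occurrence frequency of pieces[id], normalized by the
// total sentence count.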
let logprob_sp = freq[id].ln() - logsum;
// After removing the sentencepiece[i], its frequency freq[i] is
// re-assigned to its alternatives.
// new_sum = current_sum - freq[i] + freq[i] * alternatives[i].size()
//         = current_sum + freq[i] * (alternatives[i].size() - 1)
let logsum_alt = (sum + freq[id] * (alternatives[id].len() - 1) as f64).ln();
// The frequencies of alternatives are increased by freq[i].
let mut logprob_alt = 0.0;
for n in &alternatives[id] {
logprob_alt += (freq[*n] + freq[id]).ln() - logsum_alt;
}
// loss: the drop in log-likelihood when sentencepiece[i] is removed.
let loss = f * (logprob_sp - logprob_alt);
if loss.is_nan() {
panic!("pruning loss is NaN for sentencepiece {:?}", token);
}
candidates.push((id, loss));
}
}
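// Each pruning round keeps roughly `shrinking_factor` of the current
// pieces, but never shrinks below 110% of the target vocab size.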
let desired_vocab_size: usize = (self.vocab_size as usize * 11) / 10; // * 1.1
let pruned_size: usize = ((pieces.len() as f64) * self.shrinking_factor) as usize;
let pruned_size = desired_vocab_size.max(pruned_size);
candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
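// Candidates are sorted by descending loss, so the pieces whose removal
// would hurt the likelihood most are kept first, up to `pruned_size`.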
for (id, _score) in candidates {
if new_pieces.len() == pruned_size {
break;
}
new_pieces.push(pieces[id].clone());
}
new_pieces
}
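// Illustrative sketch (not part of the trainer): a standalone version of the
// per-piece loss computed above, under the assumption that the candidate's
// frequency `freq_i` would be re-assigned to each of its second-best
// alternatives. `freq_i`, `sum`, `alt_freqs` and `f` are hypothetical inputs
// mirroring `freq[id]`, `sum`, the alternatives' frequencies and the
// normalized sentence frequency used in `prune_sentence_pieces`.
fn pruning_loss_sketch(freq_i: f64, sum: f64, alt_freqs: &[f64], f: f64) -> f64 {
    // Log-probability of the piece under the current unigram LM.
    let logprob_sp = freq_i.ln() - sum.ln();
    // After removal, the normalizer grows by freq_i * (|alternatives| - 1).
    let logsum_alt = (sum + freq_i * (alt_freqs.len() as f64 - 1.0)).ln();
    // Log-probability of emitting the alternative sequence instead.
    let logprob_alt: f64 = alt_freqs
        .iter()
        .map(|&alt| (alt + freq_i).ln() - logsum_alt)
        .sum();
    // Expected drop in log-likelihood, weighted by how often the piece occurs.
    f * (logprob_sp - logprob_alt)
}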