in crates/ratchet-models/src/whisper/logit_mutators/timestamp_rules.rs [15:96]
fn apply(
    &self,
    logits: Tensor,
    tokenizer: &WhisperTokenizer,
    tokens: Option<&Tensor>,
) -> anyhow::Result<Tensor> {
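    // Pull both tensors into ndarrays so the rules can mutate the logits in place.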
    let nd_tokens = tokens
        .ok_or_else(|| anyhow::anyhow!("timestamp rules require the sampled tokens"))?
        .clone()
        .into_ndarray::<i32>();
    let mut nd_logits = logits.into_ndarray::<f32>();
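    // These rules manage timestamps directly, so <|notimestamps|> must never be sampled.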
    nd_logits
        .slice_mut(s![.., tokenizer.notimestamps() as usize])
        .map_inplace(move |el| *el = f32::NEG_INFINITY);
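    // Apply the pairing and monotonicity rules independently to each row of the batch.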
    for k in 0..nd_tokens.shape()[0] {
        let sampled_tokens = nd_tokens.slice(s![k, self.sample_begin..]);
        let sample_len = sampled_tokens.len();
        let last_was_timestamp = !sampled_tokens.is_empty()
            && sampled_tokens[sample_len - 1] >= tokenizer.timestamp_begin();
        let penultimate_was_timestamp =
            sample_len < 2 || sampled_tokens[sample_len - 2] >= tokenizer.timestamp_begin();
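        // Timestamps come in pairs (segment start, segment end). After a completed
        // pair, the next token must be text or EOT; after a lone timestamp, only
        // its closing timestamp or EOT may follow.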
        if last_was_timestamp {
            if penultimate_was_timestamp {
                nd_logits
                    .slice_mut(s![k, tokenizer.timestamp_begin()..])
                    .map_inplace(move |el| *el = f32::NEG_INFINITY);
            } else {
                nd_logits
                    .slice_mut(s![k, ..WhisperTokenizer::EOT])
                    .map_inplace(move |el| *el = f32::NEG_INFINITY);
            }
        }
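        // Collect every timestamp sampled so far; they must be non-decreasing.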
        let timestamps = sampled_tokens
            .iter()
            .filter(|x| **x >= tokenizer.timestamp_begin())
            .collect::<Vec<_>>();
        if !timestamps.is_empty() {
            // timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
            // also force each segment to have a nonzero length, to prevent infinite looping
            let timestamp_last = if last_was_timestamp && !penultimate_was_timestamp {
                *timestamps[timestamps.len() - 1]
            } else {
                timestamps[timestamps.len() - 1] + 1
            };
            nd_logits
                .slice_mut(s![k, tokenizer.timestamp_begin()..timestamp_last])
                .map_inplace(move |el| *el = f32::NEG_INFINITY);
        }
    }
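    // Nothing sampled yet beyond the prompt: the very first token must be a timestamp.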
    if nd_tokens.shape()[1] == self.sample_begin {
        // suppress generating non-timestamp tokens at the beginning
        nd_logits
            .slice_mut(s![.., ..tokenizer.timestamp_begin()])
            .map_inplace(move |el| *el = f32::NEG_INFINITY);
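        // Optionally cap the first timestamp so decoding cannot begin too late in the window.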
        if let Some(max_initial) = self.max_initial_timestamp_index {
            let last_allowed = (tokenizer.timestamp_begin() as usize) + max_initial;
            nd_logits
                .slice_mut(s![.., last_allowed + 1..])
                .map_inplace(move |el| *el = f32::NEG_INFINITY);
        }
    }
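    // Whisper heuristic: if the total probability mass on timestamp tokens outweighs
    // the most likely single text token, force a timestamp to be sampled.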
    let logprobs = nd_logits.log_softmax(1);
    // The `..` slices below already cover every row at once, so the original
    // per-row loop (whose index went unused) re-applied the same mask each
    // iteration; running the body once is equivalent.
    let timestamp_logprob = logprobs
        .slice(s![.., tokenizer.timestamp_begin()..])
        .logsumexp(1);
    let text_logprobs = logprobs.slice(s![.., ..tokenizer.timestamp_begin()]);
    let max_text_token_logprob = text_logprobs.max()?;
    if timestamp_logprob > *max_text_token_logprob {
        nd_logits
            .slice_mut(s![.., ..tokenizer.timestamp_begin()])
            .map_inplace(move |el| *el = f32::NEG_INFINITY);
    }
    Ok(Tensor::from(nd_logits))
}
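
For context, a minimal sketch of how a mutator with this signature could slot into a greedy decode step. The `LogitMutator` trait body and the `decode_step` function are assumptions inferred from this file's path and the `apply` signature above, not the crate's actual decode loop; `argmax_token` is a hypothetical helper.

trait LogitMutator {
    fn apply(
        &self,
        logits: Tensor,
        tokenizer: &WhisperTokenizer,
        tokens: Option<&Tensor>,
    ) -> anyhow::Result<Tensor>;
}

fn decode_step(
    mutators: &[Box<dyn LogitMutator>],
    tokenizer: &WhisperTokenizer,
    mut logits: Tensor,
    tokens: &Tensor,
) -> anyhow::Result<i32> {
    // Each mutator rewrites the logits in turn before sampling, so the
    // timestamp rules above see the tokens sampled so far via `tokens`.
    for m in mutators {
        logits = m.apply(logits, tokenizer, Some(tokens))?;
    }
    // Hypothetical helper: pick the highest-scoring token id from the masked logits.
    argmax_token(&logits)
}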