in tokenizers/src/processors/bert.rs [51:192]
/// Adds the BERT special tokens around the given encodings.
///
/// When `add_special_tokens` is `false` the encodings are returned untouched.
/// Otherwise:
///
/// * the first encoding (index 0) becomes `cls + tokens + sep`, with the added
///   tokens carrying type id `0`;
/// * every following encoding becomes `tokens + sep`, with the added `sep`
///   carrying type id `1`.
///
/// Each added special token gets word id `None`, offsets `(0, 0)`, a `1` in the
/// special-tokens mask and a `1` in the attention mask (the attention mask is
/// rebuilt as all ones). Overflowing encodings attached to an input receive the
/// same treatment as their parent and end up with an empty overflowing list of
/// their own.
///
/// This implementation always returns `Ok`.
fn process_encodings(
    &self,
    encodings: Vec<Encoding>,
    add_special_tokens: bool,
) -> Result<Vec<Encoding>> {
    if !add_special_tokens {
        return Ok(encodings);
    }

    // Builds `cls + encoding + sep` for a first-sequence encoding. `overflowing`
    // is attached as-is, so callers must wrap those encodings beforehand.
    let wrap_first = |encoding: &Encoding, overflowing: Vec<Encoding>| -> Encoding {
        let ids = [&[self.cls.1], encoding.get_ids(), &[self.sep.1]].concat();
        let type_ids = [&[0], encoding.get_type_ids(), &[0]].concat();
        let tokens = [
            std::slice::from_ref(&self.cls.0),
            encoding.get_tokens(),
            std::slice::from_ref(&self.sep.0),
        ]
        .concat();
        let words = [&[None], encoding.get_word_ids(), &[None]].concat();
        let offsets = [&[(0, 0)], encoding.get_offsets(), &[(0, 0)]].concat();
        let special_tokens = [&[1u32], &vec![0; encoding.get_ids().len()][..], &[1]].concat();
        let attention_mask = vec![1; ids.len()];
        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
        // the special tokens.
        let sequence_ranges = AHashMap::from_iter(vec![(0, 1..ids.len() - 1)]);
        Encoding::new(
            ids,
            type_ids,
            tokens,
            words,
            offsets,
            special_tokens,
            attention_mask,
            overflowing,
            sequence_ranges,
        )
    };

    // Builds `encoding + sep` for a pair-sequence encoding. `overflowing` is
    // attached as-is, so callers must wrap those encodings beforehand.
    let wrap_pair = |encoding: &Encoding, overflowing: Vec<Encoding>| -> Encoding {
        let pair_ids = [encoding.get_ids(), &[self.sep.1]].concat();
        let pair_type_ids = [encoding.get_type_ids(), &[1]].concat();
        let pair_tokens = [encoding.get_tokens(), std::slice::from_ref(&self.sep.0)].concat();
        let pair_words = [encoding.get_word_ids(), &[None]].concat();
        let pair_offsets = [encoding.get_offsets(), &[(0, 0)]].concat();
        // The mask is sized from the ids (as in the first-sequence case) so that it
        // always lines up with `pair_ids`.
        let pair_special_tokens = [&vec![0u32; encoding.get_ids().len()][..], &[1]].concat();
        let pair_attention_mask = vec![1; pair_ids.len()];
        // For compatibility with `TemplateProcessing`, the sequence_ranges shouldn't contain
        // the special tokens.
        let pair_sequence_ranges = AHashMap::from_iter(vec![(1, 0..pair_ids.len() - 1)]);
        Encoding::new(
            pair_ids,
            pair_type_ids,
            pair_tokens,
            pair_words,
            pair_offsets,
            pair_special_tokens,
            pair_attention_mask,
            overflowing,
            pair_sequence_ranges,
        )
    };

    let encodings: Vec<Encoding> = encodings
        .into_iter()
        .enumerate()
        .map(|(i, mut encoding)| {
            // Detach the overflowing encodings first: they are wrapped on their own
            // (with no nested overflow) and then re-attached to the wrapped parent.
            let overflowing = encoding.take_overflowing();
            if i == 0 {
                let wrapped_overflow: Vec<Encoding> = overflowing
                    .iter()
                    .map(|overflow| wrap_first(overflow, vec![]))
                    .collect();
                wrap_first(&encoding, wrapped_overflow)
            } else {
                let wrapped_overflow: Vec<Encoding> = overflowing
                    .iter()
                    .map(|overflow| wrap_pair(overflow, vec![]))
                    .collect();
                wrap_pair(&encoding, wrapped_overflow)
            }
        })
        .collect();
    Ok(encodings)
}