bindings/node/src/tasks/tokenizer.rs (111 lines of code) (raw):

extern crate tokenizers as tk; use crate::encoding::*; use crate::tokenizer::Tokenizer; use napi::bindgen_prelude::*; use tk::tokenizer::{EncodeInput, Encoding}; pub struct EncodeTask<'s> { pub tokenizer: Tokenizer, pub input: Option<EncodeInput<'s>>, pub add_special_tokens: bool, } impl Task for EncodeTask<'static> { type Output = Encoding; type JsValue = JsEncoding; fn compute(&mut self) -> Result<Self::Output> { self .tokenizer .tokenizer .read() .unwrap() .encode_char_offsets( self .input .take() .ok_or(Error::from_reason("No provided input"))?, self.add_special_tokens, ) .map_err(|e| Error::from_reason(format!("{}", e))) } fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> { Ok(JsEncoding { encoding: Some(output), }) } } pub struct DecodeTask { pub tokenizer: Tokenizer, pub ids: Vec<u32>, pub skip_special_tokens: bool, } impl Task for DecodeTask { type Output = String; type JsValue = String; fn compute(&mut self) -> Result<Self::Output> { self .tokenizer .tokenizer .read() .unwrap() .decode(&self.ids, self.skip_special_tokens) .map_err(|e| Error::from_reason(format!("{}", e))) } fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> { Ok(output) } } pub struct EncodeBatchTask<'s> { pub tokenizer: Tokenizer, pub inputs: Option<Vec<EncodeInput<'s>>>, pub add_special_tokens: bool, } impl Task for EncodeBatchTask<'static> { type Output = Vec<Encoding>; type JsValue = Vec<JsEncoding>; fn compute(&mut self) -> Result<Self::Output> { self .tokenizer .tokenizer .read() .unwrap() .encode_batch_char_offsets( self .inputs .take() .ok_or(Error::from_reason("No provided input"))?, self.add_special_tokens, ) .map_err(|e| Error::from_reason(format!("{}", e))) } fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> { Ok( output .into_iter() .map(|encoding| JsEncoding { encoding: Some(encoding), }) .collect(), ) } } pub struct DecodeBatchTask { pub tokenizer: Tokenizer, pub ids: Vec<Vec<u32>>, pub skip_special_tokens: bool, } impl Task for DecodeBatchTask { type Output = Vec<String>; type JsValue = Vec<String>; fn compute(&mut self) -> Result<Self::Output> { let ids: Vec<_> = self.ids.iter().map(|s| s.as_slice()).collect(); self .tokenizer .tokenizer .read() .unwrap() .decode_batch(&ids, self.skip_special_tokens) .map_err(|e| Error::from_reason(format!("{}", e))) } fn resolve(&mut self, _env: Env, output: Self::Output) -> Result<Self::JsValue> { Ok(output) } }