// bindings/python/src/utils/pretokenization.rs
use pyo3::exceptions;
use pyo3::prelude::*;
use pyo3::types::*;

use tokenizers as tk;
use tk::{OffsetReferential, OffsetType, Offsets, PreTokenizedString, Token};

use super::{
    DestroyPtr, PyNormalizedString, PyNormalizedStringRefMut, RefMutContainer, RefMutGuard,
};
use crate::encoding::PyEncoding;
use crate::error::ToPyResult;
use crate::token::PyToken;

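// Shared helpers used by both the owned `PyPreTokenizedString` and the borrowed
// `PyPreTokenizedStringRefMut` bindings defined below.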
fn split(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResult<()> {
if !func.is_callable() {
        Err(exceptions::PyTypeError::new_err(
            "`split` expects a callable with the signature: \
`fn(index: int, normalized: NormalizedString) -> List[NormalizedString]`",
))
} else {
ToPyResult(pretok.split(|i, normalized| {
let output = func.call((i, PyNormalizedString::from(normalized)), None)?;
Ok(output
.extract::<Vec<PyNormalizedString>>()?
.into_iter()
.map(tk::NormalizedString::from))
}))
.into()
}
}
fn normalize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResult<()> {
if !func.is_callable() {
        Err(exceptions::PyTypeError::new_err(
            "`normalize` expects a callable with the signature: \
`fn(normalized: NormalizedString)`",
))
} else {
ToPyResult(pretok.normalize(|normalized| {
let norm = PyNormalizedStringRefMut::new(normalized);
func.call((norm.get().clone(),), None)?;
Ok(())
}))
.into()
}
}
fn tokenize(pretok: &mut PreTokenizedString, func: &Bound<'_, PyAny>) -> PyResult<()> {
if !func.is_callable() {
        Err(exceptions::PyTypeError::new_err(
            "`tokenize` expects a callable with the signature: \
`fn(str) -> List[Token]`",
))
} else {
ToPyResult(pretok.tokenize(|normalized| {
let output = func.call((normalized.get(),), None)?;
Ok(output
.extract::<Bound<PyList>>()?
.into_iter()
.map(|obj| Ok(Token::from(obj.extract::<PyToken>()?)))
.collect::<PyResult<Vec<_>>>()?)
}))
.into()
}
}
/// Wrapper around `OffsetReferential`, parsed from the Python strings `"original"` and `"normalized"`.
#[derive(Clone)]
pub struct PyOffsetReferential(OffsetReferential);
impl FromPyObject<'_> for PyOffsetReferential {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<String>()?;
Ok(Self(match s.as_ref() {
"original" => Ok(OffsetReferential::Original),
"normalized" => Ok(OffsetReferential::Normalized),
_ => Err(exceptions::PyValueError::new_err(
"Wrong value for OffsetReferential, expected one of `original, normalized`",
)),
}?))
}
}
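/// Wrapper around `OffsetType`, parsed from the Python strings `"byte"` and `"char"`.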
#[derive(Clone)]
pub struct PyOffsetType(OffsetType);
impl FromPyObject<'_> for PyOffsetType {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let s = obj.extract::<String>()?;
Ok(Self(match s.as_ref() {
"byte" => Ok(OffsetType::Byte),
"char" => Ok(OffsetType::Char),
_ => Err(exceptions::PyValueError::new_err(
"Wrong value for OffsetType, expected one of `byte, char`",
)),
}?))
}
}
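/// A single split as surfaced to Python: the split's text, its offsets, and its tokens,
/// if the split has already been tokenized.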
type PySplit = (String, Offsets, Option<Vec<PyToken>>);
fn get_splits(
pretok: &PreTokenizedString,
offset_referential: PyOffsetReferential,
offset_type: PyOffsetType,
) -> Vec<PySplit> {
pretok
.get_splits(offset_referential.0, offset_type.0)
.into_iter()
.map(|(s, o, t)| {
(
s.to_owned(),
o,
t.as_ref()
.map(|tokens| tokens.iter().map(|t| t.clone().into()).collect()),
)
})
.collect()
}
fn to_encoding(
pretok: &PreTokenizedString,
type_id: u32,
word_idx: Option<u32>,
) -> PyResult<PyEncoding> {
Ok(ToPyResult(
pretok
.clone()
.into_encoding(word_idx, type_id, tk::OffsetType::Char),
)
.into_py()?
.into())
}
/// PreTokenizedString
///
/// Wrapper over a string that provides ways to normalize, pre-tokenize, and tokenize
/// the underlying string, while keeping track of the alignment information (offsets).
///
/// The PreTokenizedString manages what we call `splits`. Each split represents a substring
/// which is a subpart of the original string, with the relevant offsets and tokens.
///
/// When calling one of the methods used to modify the PreTokenizedString (namely one of
/// `split`, `normalize` or `tokenize`), only the `splits` that don't have any associated
/// tokens will get modified.
///
/// Args:
/// sequence: str:
/// The string sequence used to initialize this PreTokenizedString
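///
/// Example (a minimal sketch; the lambdas are illustrative, any callables with the
/// documented signatures work):
///
/// ```python
/// from tokenizers import PreTokenizedString
///
/// pretok = PreTokenizedString("Hello there!")
/// pretok.normalize(lambda ns: ns.lowercase())
/// pretok.split(lambda i, ns: ns.split(" ", "removed"))
/// print(pretok.get_splits())
/// # e.g. [('hello', (0, 5), None), ('there!', (6, 12), None)]
/// ```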
#[pyclass(module = "tokenizers", name = "PreTokenizedString")]
pub struct PyPreTokenizedString {
pub(crate) pretok: tk::PreTokenizedString,
}
impl From<PreTokenizedString> for PyPreTokenizedString {
fn from(pretok: PreTokenizedString) -> Self {
Self { pretok }
}
}
impl From<PyPreTokenizedString> for PreTokenizedString {
fn from(pretok: PyPreTokenizedString) -> Self {
pretok.pretok
}
}
#[pymethods]
impl PyPreTokenizedString {
#[new]
#[pyo3(text_signature = "(self, sequence)")]
fn new(s: &str) -> Self {
PreTokenizedString::from(s).into()
}
/// Split the PreTokenizedString using the given `func`
///
/// Args:
    ///     func: Callable[[int, NormalizedString], List[NormalizedString]]:
    ///         The function used to split each underlying split.
    ///         It is expected to return a list of `NormalizedString` representing the new
    ///         splits. If the given `NormalizedString` does not need any splitting, we can
    ///         just return it directly.
/// In order for the offsets to be tracked accurately, any returned `NormalizedString`
/// should come from calling either `.split` or `.slice` on the received one.
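    ///
    /// Example (a sketch using `NormalizedString.split(pattern, behavior)` to cut each
    /// split on whitespace, dropping the matched delimiter):
    ///
    /// ```python
    /// pretok.split(lambda i, normalized: normalized.split(" ", "removed"))
    /// ```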
#[pyo3(text_signature = "(self, func)")]
fn split(&mut self, func: &Bound<'_, PyAny>) -> PyResult<()> {
split(&mut self.pretok, func)
}
/// Normalize each split of the `PreTokenizedString` using the given `func`
///
/// Args:
/// func: Callable[[NormalizedString], None]:
/// The function used to normalize each underlying split. This function
    ///         does not need to return anything; calling the methods on the provided
    ///         NormalizedString modifies it in place.
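    ///
    /// Example (a sketch; `lowercase` mutates the received `NormalizedString` in
    /// place, so the callable returns nothing):
    ///
    /// ```python
    /// pretok.normalize(lambda normalized: normalized.lowercase())
    /// ```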
#[pyo3(text_signature = "(self, func)")]
fn normalize(&mut self, func: &Bound<'_, PyAny>) -> PyResult<()> {
normalize(&mut self.pretok, func)
}
/// Tokenize each split of the `PreTokenizedString` using the given `func`
///
/// Args:
/// func: Callable[[str], List[Token]]:
/// The function used to tokenize each underlying split. This function must return
/// a list of Token generated from the input str.
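    ///
    /// Example (a sketch; assumes `Token` is exposed as `tokenizers.Token` with a
    /// `(id, value, offsets)` constructor, and uses 0 as a placeholder id):
    ///
    /// ```python
    /// from tokenizers import Token
    ///
    /// pretok.tokenize(lambda s: [Token(0, s, (0, len(s)))])
    /// ```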
#[pyo3(text_signature = "(self, func)")]
fn tokenize(&mut self, func: &Bound<'_, PyAny>) -> PyResult<()> {
tokenize(&mut self.pretok, func)
}
/// Return an Encoding generated from this PreTokenizedString
///
/// Args:
/// type_id: int = 0:
/// The type_id to be used on the generated Encoding.
///
/// word_idx: Optional[int] = None:
/// An optional word index to be used for each token of this Encoding. If provided,
/// all the word indices in the generated Encoding will use this value, instead
/// of the one automatically tracked during pre-tokenization.
///
/// Returns:
/// An Encoding
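    ///
    /// Example (a sketch; only meaningful once `tokenize` has run, so that every
    /// split carries tokens):
    ///
    /// ```python
    /// encoding = pretok.to_encoding(type_id=0)
    /// print(encoding.tokens, encoding.offsets)
    /// ```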
#[pyo3(signature = (type_id = 0, word_idx = None))]
#[pyo3(text_signature = "(self, type_id=0, word_idx=None)")]
fn to_encoding(&self, type_id: u32, word_idx: Option<u32>) -> PyResult<PyEncoding> {
to_encoding(&self.pretok, type_id, word_idx)
}
/// Get the splits currently managed by the PreTokenizedString
///
/// Args:
/// offset_referential: :obj:`str`
/// Whether the returned splits should have offsets expressed relative
    ///         to the original string or the normalized one. Choices: "original", "normalized".
///
/// offset_type: :obj:`str`
    ///         Whether the returned splits should have offsets expressed in bytes or chars.
    ///         When slicing a str, we usually want char offsets, which is the default value.
    ///         In some cases it can be useful to have these offsets expressed in bytes
    ///         instead, so it is possible to change that here.
    ///         Choices: "byte", "char"
///
    /// Returns:
/// A list of splits
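    ///
    /// Example (a sketch; before tokenization, the last element of each tuple is None):
    ///
    /// ```python
    /// for text, (start, stop), tokens in pretok.get_splits(offset_type="char"):
    ///     print(text, (start, stop), tokens)
    /// ```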
#[pyo3(signature = (
offset_referential = PyOffsetReferential(OffsetReferential::Original),
offset_type = PyOffsetType(OffsetType::Char)
))]
#[pyo3(text_signature = "(self, offset_referential=\"original\", offset_type=\"char\")")]
fn get_splits(
&self,
offset_referential: PyOffsetReferential,
offset_type: PyOffsetType,
) -> Vec<PySplit> {
get_splits(&self.pretok, offset_referential, offset_type)
}
}
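/// Non-owning, mutable view over a `PreTokenizedString`, exposed to Python under the same
/// `PreTokenizedString` name. It is handed out while a custom pre-tokenizer's `pre_tokenize`
/// hook runs, and is destroyed when that call returns; any later use raises the error from
/// `destroyed_error`.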
#[pyclass(module = "tokenizers", name = "PreTokenizedString")]
#[derive(Clone)]
pub struct PyPreTokenizedStringRefMut {
inner: RefMutContainer<PreTokenizedString>,
}
impl DestroyPtr for PyPreTokenizedStringRefMut {
fn destroy(&mut self) {
self.inner.destroy();
}
}
impl PyPreTokenizedStringRefMut {
pub fn new(pretok: &mut tk::PreTokenizedString) -> RefMutGuard<'_, Self> {
// SAFETY: This is safe because we return a RefMutGuard here.
// The compiler will make sure the &mut stays valid as necessary.
RefMutGuard::new(Self {
inner: RefMutContainer::new(pretok),
})
}
pub fn destroyed_error() -> PyErr {
exceptions::PyException::new_err(
"Cannot use a PreTokenizedStringRefMut outside `pre_tokenize`",
)
}
}
#[pymethods]
impl PyPreTokenizedStringRefMut {
fn split(&mut self, func: &Bound<'_, PyAny>) -> PyResult<()> {
self.inner
.map_mut(|pretok| split(pretok, func))
.ok_or_else(PyPreTokenizedStringRefMut::destroyed_error)?
}
fn normalize(&mut self, func: &Bound<'_, PyAny>) -> PyResult<()> {
self.inner
.map_mut(|pretok| normalize(pretok, func))
.ok_or_else(PyPreTokenizedStringRefMut::destroyed_error)?
}
fn tokenize(&mut self, func: &Bound<'_, PyAny>) -> PyResult<()> {
self.inner
.map_mut(|pretok| tokenize(pretok, func))
.ok_or_else(PyPreTokenizedStringRefMut::destroyed_error)?
}
#[pyo3(signature = (type_id = 0, word_idx = None))]
fn to_encoding(&self, type_id: u32, word_idx: Option<u32>) -> PyResult<PyEncoding> {
self.inner
.map(|pretok| to_encoding(pretok, type_id, word_idx))
.ok_or_else(PyPreTokenizedStringRefMut::destroyed_error)?
}
#[pyo3(signature = (
offset_referential = PyOffsetReferential(OffsetReferential::Original),
offset_type = PyOffsetType(OffsetType::Char)
))]
fn get_splits(
&self,
offset_referential: PyOffsetReferential,
offset_type: PyOffsetType,
) -> PyResult<Vec<PySplit>> {
self.inner
.map(|pretok| get_splits(pretok, offset_referential, offset_type))
.ok_or_else(PyPreTokenizedStringRefMut::destroyed_error)
}
}