# maga_transformer/utils/tokenizer_utils.py
from typing import List, Optional, Tuple, Union

from transformers import (PreTrainedTokenizer,
                          PreTrainedTokenizerFast)

class DecodingState(object):
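    """Per-stream state carried across incremental detokenization steps.

    Tracks how many input ids have already been consumed, the token strings
    produced so far, and the prefix/read offsets used to decide when a
    complete piece of text can be emitted.
    """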
    last_input_id_index: int
    all_text: str
    prev_tokens: Optional[List[str]]
    prefix_offset: int = 0
    read_offset: int = 0

    def __init__(self):
        self.last_input_id_index = 0
        self.prev_tokens = None
        self.prefix_offset = 0
        self.read_offset = 0
        self.all_text = ""

    def update(self,
               last_input_id_index: int,
               prev_tokens: Optional[List[str]],
               prefix_offset: int = 0, read_offset: int = 0):
        self.last_input_id_index = last_input_id_index
        self.prev_tokens = prev_tokens
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

    def __str__(self):
        return f"{self.__class__.__name__}(" + ", ".join([f"{k}={v!r}" for k, v in self.__dict__.items()]) + ")"

# Referenced from
# https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/tokenizer.py#L68
def _convert_tokens_to_string_with_added_encoders(
    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    output_tokens: List[str],
    skip_special_tokens: bool,
    spaces_between_special_tokens: bool,
) -> str:
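    """Join tokens into text while handling added (non-vocab) tokens separately.

    Added tokens are not guaranteed to round-trip through
    convert_tokens_to_string, so each run of regular tokens is converted on
    its own and the added tokens are spliced back in as raw strings.
    """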
    sub_texts = []
    current_sub_text = []
    all_special_tokens = set(tokenizer.all_special_tokens)
    legacy_added_tokens = set(tokenizer._added_tokens_encoder.keys()) - set(tokenizer.all_special_tokens) | {
        token for token in tokenizer.additional_special_tokens
        if tokenizer.convert_tokens_to_ids(token) >= tokenizer.vocab_size
    }
    for token in output_tokens:
        if skip_special_tokens and token in all_special_tokens:
            continue
        if token in legacy_added_tokens:
            if current_sub_text:
                sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
                sub_texts.append(sub_text)
                current_sub_text = []
            sub_texts.append(token)
        else:
            current_sub_text.append(token)
    if current_sub_text:
        sub_text = tokenizer.convert_tokens_to_string(current_sub_text)
        sub_texts.append(sub_text)
    if spaces_between_special_tokens:
        return " ".join(sub_texts)
    else:
        return "".join(sub_texts)

class IncrementDecodingUtils(object):
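    """Helpers that turn a growing list of token ids into incremental text deltas."""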
    @staticmethod
    def detokenize_incrementally(
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        all_input_ids: List[int],
        state: DecodingState,
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = True,
    ):
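        """Return the text newly decoded since the last call, or "" if the
        latest tokens still end in an incomplete (replacement-character)
        sequence. `state` is updated in place so the next call only needs to
        decode the ids appended after `state.last_input_id_index`.
        """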
        output_tokens, prefix_offset, read_offset = IncrementDecodingUtils._get_new_tokens(
            tokenizer, all_input_ids, state, skip_special_tokens)
        prefix_text, new_text = IncrementDecodingUtils._convert_token_to_string(
            tokenizer, output_tokens, prefix_offset, read_offset,
            skip_special_tokens, spaces_between_special_tokens)
        if len(new_text) > len(prefix_text) and not new_text.endswith("�"):
            # The pending text is complete: emit the suffix and advance the offsets.
            new_text = new_text[len(prefix_text):]
            state.update(len(output_tokens), output_tokens, read_offset, len(output_tokens))
            return new_text
        else:
            # Still inside a partially decoded byte sequence: keep buffering.
            state.update(len(output_tokens), output_tokens, prefix_offset, read_offset)
            return ""

    @staticmethod
    def _get_new_tokens(tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
                        all_input_ids: List[int],
                        state: DecodingState,
                        skip_special_tokens: bool = False) -> Tuple[List[str], int, int]:
        # First call for this stream: no cached tokens yet, convert everything.
        if state.prev_tokens is None:
            new_tokens = tokenizer.convert_ids_to_tokens(
                all_input_ids, skip_special_tokens=skip_special_tokens)
            prefix_offset = 0
            read_offset = 0
            output_tokens = new_tokens
        else:
            # Only convert the ids appended since the previous call.
            new_tokens = tokenizer.convert_ids_to_tokens(
                all_input_ids[state.last_input_id_index:], skip_special_tokens=skip_special_tokens)
            prefix_offset = state.prefix_offset
            read_offset = state.read_offset
            output_tokens = state.prev_tokens + new_tokens
        return output_tokens, prefix_offset, read_offset

    @staticmethod
    def _convert_token_to_string(
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
        output_tokens: List[str],
        prefix_offset: int,
        read_offset: int,
        skip_special_tokens: bool = False,
        spaces_between_special_tokens: bool = True,
    ) -> Tuple[str, str]:
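        """Decode the previously read window and the full pending window.

        Fast tokenizers (and slow ones without added vocab) can use
        convert_tokens_to_string directly; otherwise fall back to the
        added-token-aware join above.
        """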
        if tokenizer.is_fast or not tokenizer.get_added_vocab():
            prefix_text = tokenizer.convert_tokens_to_string(
                output_tokens[prefix_offset:read_offset])
            new_text = tokenizer.convert_tokens_to_string(
                output_tokens[prefix_offset:])
        else:
            prefix_text = _convert_tokens_to_string_with_added_encoders(
                tokenizer,
                output_tokens[prefix_offset:read_offset],
                skip_special_tokens=skip_special_tokens,
                spaces_between_special_tokens=spaces_between_special_tokens,
            )
            new_text = _convert_tokens_to_string_with_added_encoders(
                tokenizer,
                output_tokens[prefix_offset:],
                skip_special_tokens=skip_special_tokens,
                spaces_between_special_tokens=spaces_between_special_tokens,
            )
        return prefix_text, new_text
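

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module's original surface):
# shows how a streaming generation loop might emit text deltas with
# DecodingState and IncrementDecodingUtils. The tokenizer name and the id
# stream below are assumptions chosen for the example; any HuggingFace
# tokenizer works the same way.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical choice
    state = DecodingState()
    all_input_ids: List[int] = []

    # Pretend these ids arrive one at a time from a decode loop.
    for token_id in tokenizer.encode("Hello, incremental decoding!"):
        all_input_ids.append(token_id)
        delta = IncrementDecodingUtils.detokenize_incrementally(
            tokenizer, all_input_ids, state)
        if delta:
            print(delta, end="", flush=True)
    print()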