maga_transformer/pipeline/chatapi_format.py (38 lines of code) (raw):

import json
from typing import List, Dict, Union, Any
from transformers import PreTrainedTokenizerBase


def _format_tokens(content_tokens: List[int], role_special_tokens: Any) -> List[int]:
    """Wrap a message's content tokens with the role's special tokens.

    Result layout: role prefix token ids + content tokens + role EOS token ids.
    `role_special_tokens` is expected to expose `token_ids` and `eos_token_ids`
    lists of ints (project-defined type — not visible here, confirm shape).
    """
    content_tokens = role_special_tokens.token_ids + content_tokens + role_special_tokens.eos_token_ids
    return content_tokens


# from modeling_baichuan.py
def encode_chatapi(messages: List[Dict[str, str]], special_tokens: Any, tokenizer: PreTrainedTokenizerBase) -> List[int]:
    """Encode a ChatAPI-style message list into a flat list of token ids.

    Messages are consumed newest-first (reversed) and accumulated in
    user-terminated "rounds", so that if a token budget ever applied, the most
    recent rounds would be kept. Each user/assistant message is wrapped with its
    role's special tokens via `_format_tokens`. A `system` message, if present,
    must be the first entry of `messages`; its tokens are prepended at the end.
    The sequence is finished with the assistant role's prefix tokens so the
    model continues as the assistant.

    Raises:
        Exception: if a system message is not the first message.
        ValueError: for any role other than user/assistant/system.
    """
    # Budget is effectively unlimited (2**32); kept so the truncation logic
    # below stays in place if a real limit is wired in later.
    max_input_tokens = 2 ** 32 # int max, maybe support max_history_len in generate_config
    total_input: List[int] = []   # tokens accepted so far (oldest-first order preserved)
    round_input: List[int] = []   # current round being assembled (assistant replies + the user turn that opened it)
    system_input = ''
    # Iterate newest message first; `i` indexes into the reversed sequence,
    # so `i == len(messages) - 1` is the ORIGINAL first message.
    for i, message in enumerate(messages[::-1]):
        content_tokens = tokenizer.encode(message['content'])
        if message['role'] == 'user':
            # A user message closes the round (we are walking backwards):
            # prepend it, then try to commit the whole round to the total.
            round_input = _format_tokens(content_tokens, special_tokens.user) + round_input
            if total_input and len(total_input) + len(round_input) > max_input_tokens:
                # Adding this round would blow the budget — stop, keep what we have.
                break
            else:
                # Prepend the round so chronological order is preserved.
                total_input = round_input + total_input
                if len(total_input) >= max_input_tokens:
                    break
                else:
                    round_input = []
        elif message['role'] == 'assistant':
            # Assistant replies accumulate into the round until its user turn is seen.
            round_input = _format_tokens(content_tokens, special_tokens.assistant) + round_input
        elif message['role'] == 'system':
            if i != len(messages) - 1:
                raise Exception('system role must be 1st message')
            # Deferred: system tokens are prepended after the loop.
            system_input = message['content']
        else:
            raise ValueError(f"message role not supported yet: {message['role']}")
    if system_input:
        total_input = _format_tokens(tokenizer.encode(system_input), special_tokens.system) + total_input
    # bos_token_id == -1 is the sentinel for "model has no BOS token".
    if special_tokens.bos_token_id != -1:
        total_input = [special_tokens.bos_token_id] + total_input
    total_input = total_input[-max_input_tokens:]  # truncate left
    # Open the assistant turn so generation continues as the assistant.
    total_input += special_tokens.assistant.token_ids
    return total_input