maga_transformer/utils/word_util.py (84 lines of code) (raw):

import numpy as np
import torch
from typing import List, Union, Any


def remove_padding_eos(token_ids: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    """Drop every occurrence of ``eos_token_id`` from a token-id tensor.

    Args:
        token_ids: token ids, expected shape ``[max_length]``.
        eos_token_id: id to filter out (used here as padding).

    Returns:
        A CPU ``torch.IntTensor`` holding the remaining ids in their original order.
    """
    out_token_ids = token_ids.cpu().numpy()
    # NOTE: this removes *all* matches, not only trailing padding.
    out_token_ids = out_token_ids[out_token_ids != eos_token_id].tolist()
    return torch.IntTensor(out_token_ids)


def remove_padding_eos_for_list(token_ids_list: List[torch.Tensor], eos_token_id: int) -> List[torch.Tensor]:
    """Apply ``remove_padding_eos`` to each tensor in ``token_ids_list``.

    Args:
        token_ids_list: tensors of token ids (a sub batch of a stream).
        eos_token_id: id to filter out of every tensor.

    Returns:
        A new list with one filtered ``IntTensor`` per input tensor.
    """
    return [remove_padding_eos(token_ids, eos_token_id) for token_ids in token_ids_list]


def get_list_dim(origin: Any) -> int:
    """Return the nesting depth of ``origin``, following first elements.

    A non-list gives 0, an empty list gives 1, and otherwise the result is
    1 + the depth of the first element. Only ``list`` instances count as a
    level; tuples and arrays are treated as leaves.
    """
    dim = 0
    node = origin
    while isinstance(node, list):
        dim += 1
        if not node:
            break
        node = node[0]
    return dim


def to_word_list_format(words_list: List[List[List[int]]]) -> np.ndarray:
    """Pack per-batch word token lists into a dense ``int32`` array.

    Args:
        words_list: shape ``[batch_size, word_num, word_token_size]``. Each
            innermost list is the token ids of one word; empty words are skipped.

    Returns:
        A C-contiguous ``int32`` array of shape ``[batch_size, 2, max_seq_length]``.
        For each batch item:

        * row 0 holds the concatenated token ids, zero padded;
        * row 1 holds the cumulative end offset of each word, padded with ``-1``.

        ``max_seq_length`` is the longest concatenated id list, and at least 1.
        An empty ``words_list`` yields an array of shape ``[0, 2, 1]``.
    """
    flat_ids: List[List[int]] = []
    offsets: List[List[int]] = []
    for words in words_list:
        non_empty = [ids for ids in words if len(ids) > 0]
        flat_ids.append([tok for ids in non_empty for tok in ids])
        offsets.append(np.cumsum([len(ids) for ids in non_empty]).tolist())

    # default=0 keeps an empty batch from raising ValueError in max().
    pad_to = max(1, max((len(ids) for ids in flat_ids), default=0))
    result = np.zeros((len(words_list), 2, pad_to), dtype=np.int32)
    result[:, 1, :] = -1  # offset rows are padded with -1, id rows with 0
    for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
        result[i, 0, :len(ids)] = ids
        result[i, 1, :len(offs)] = offs
    return np.ascontiguousarray(result)


def get_stop_word_slices(stop_word_list: List[Union[str, List[int]]]) -> List[Union[str, List[int]]]:
    """Return each stop word followed by all of its non-empty proper prefixes.

    Prefixes are emitted longest first; for example ``['abc']`` gives
    ``['abc', 'ab', 'a']``. This works for strings and token-id lists alike.
    Duplicates are not removed.
    """
    result: List[Union[str, List[int]]] = []
    for stop_word in stop_word_list:
        result.append(stop_word)
        result.extend(stop_word[:-i] for i in range(1, len(stop_word)))
    return result


def is_truncated(input_str: str, trunc_strs: List[str], is_streaming: bool) -> bool:
    """Return True if ``truncate_response_with_stop_words`` would shorten ``input_str``.

    An empty ``input_str`` is never considered truncated.
    """
    if not input_str:
        return False
    truncated = truncate_response_with_stop_words(input_str, trunc_strs, is_streaming)
    return len(truncated) != len(input_str)


def truncate_response_with_stop_words(response: str, stop_word_strs: List[str], is_streaming: bool = True) -> str:
    """Remove a stop word from ``response``.

    Empty stop words are ignored in both modes.

    * Streaming: strip the first non-empty stop word, in list order, that
      ``response`` ends with. Only that one suffix is removed.
    * Non-streaming: cut ``response`` at the earliest position where any
      non-empty stop word occurs. The stop word and everything after it are dropped.
    """
    if is_streaming:
        for stop_word in stop_word_strs:
            if stop_word and response.endswith(stop_word):
                return response[:-len(stop_word)]
        return response

    hits = [idx for idx in (response.find(w) for w in stop_word_strs if w) if idx != -1]
    return response[:min(hits)] if hits else response


def truncate_token_with_stop_word_id(tokens: List[int], stop_word_ids: List[List[int]]) -> List[int]:
    """Strip the first stop-word token sequence that ``tokens`` ends with.

    Args:
        tokens: generated token ids.
        stop_word_ids: candidate stop words, each a list of token ids. They are
            checked in order, and empty entries are ignored.

    Returns:
        ``tokens`` without the matched suffix, or ``tokens`` unchanged if no
        stop word matches.
    """
    for stop_word_id in stop_word_ids:
        if stop_word_id and tokens[-len(stop_word_id):] == stop_word_id:
            return tokens[:-len(stop_word_id)]
    return tokens


def match_stop_words(response: str, stop_word_strs: List[str]) -> bool:
    """Return True if ``response`` ends with any non-empty stop word."""
    return any(stop_word and response.endswith(stop_word) for stop_word in stop_word_strs)


if __name__ == "__main__":
    stop_words = ['abc', '11123']
    print(get_stop_word_slices(stop_words))