in tensorrtllm/run_eval.py [0:0]
# Module-level imports this method relies on:
#   import torch
#   from concurrent.futures import ThreadPoolExecutor
def process_batch(self, mel, mel_input_lengths, num_threads=4, max_new_tokens=96):
    # Encode the decoder text prefix once; every item in the batch shares it.
    prompt_id = self.tokenizer.encode(
        self.text_prefix, allowed_special=self.tokenizer.special_tokens_set)
    prompt_id = torch.tensor(prompt_id)
    batch_size = len(mel)
    decoder_input_ids = prompt_id.repeat(batch_size, 1)
    with torch.no_grad():
        # Normalize the mel input to a single (batch, time, n_mels) fp16 tensor.
        if isinstance(mel, list):
            mel = torch.stack(
                [m.transpose(1, 2).type(torch.float16).squeeze(0) for m in mel])
        else:
            mel = mel.transpose(1, 2)

        # Never spawn more worker threads than there are items to process.
        num_threads = min(num_threads, batch_size)
        # Split the batch into chunks of size batch_size // num_threads. If
        # batch_size is not divisible by num_threads, torch.split yields one
        # extra remainder chunk, which simply queues behind the others.
        mel_batches = torch.split(mel, batch_size // num_threads)
        mel_input_lengths_batches = torch.split(
            mel_input_lengths, batch_size // num_threads)

        texts_list = []
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = []
            for i, mel_batch in enumerate(mel_batches):
                current_length = mel_batch.size(0)
                # Slice the shared prompt to the chunk size, since the last
                # chunk may be smaller than the others.
                futures.append(executor.submit(
                    self.process_single_batch,
                    mel_batch,
                    decoder_input_ids[:current_length],
                    mel_input_lengths_batches[i],
                    max_new_tokens))
            # Collect results in submission order so transcripts stay aligned
            # with the input batch.
            for future in futures:
                texts_list.extend(future.result())
    return texts_list
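
One subtlety worth calling out: torch.split takes a chunk size, not a chunk count, so when batch_size is not divisible by num_threads the method produces one extra remainder chunk rather than fewer, larger ones; the thread pool just queues the extra task. A quick demonstration of that behavior:

import torch

x = torch.arange(10)
# chunk size 10 // 4 = 2  ->  five chunks of size 2, not four chunks
chunks = torch.split(x, 10 // 4)
print([c.tolist() for c in chunks])
# [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]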
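
For context, a minimal sketch of how this method might be driven. Everything here is illustrative, not taken from the repo: `model` stands in for whatever class in run_eval.py defines process_batch and process_single_batch, and the mel dimensions are the usual Whisper defaults, assumed rather than confirmed.

import torch

batch_size = 8
n_mels, n_frames = 80, 3000  # typical Whisper mel shape (assumption)

# One (1, n_mels, n_frames) spectrogram per utterance, passed as a list,
# which exercises the `isinstance(mel, list)` branch above.
mel = [torch.randn(1, n_mels, n_frames, dtype=torch.float16)
       for _ in range(batch_size)]
mel_input_lengths = torch.full((batch_size,), n_frames, dtype=torch.int32)

texts = model.process_batch(mel, mel_input_lengths,
                            num_threads=4, max_new_tokens=96)
assert len(texts) == batch_size  # one transcript per input, in batch order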