# process_batch — excerpt from tensorrtllm/run_eval.py

    def process_batch(self, mel, mel_input_lengths, num_threads=4, max_new_tokens=96):
        """Transcribe a batch of mel spectrograms, fanning work out over threads.

        Args:
            mel: list of per-item mel tensors (each presumably shaped
                (1, n_mels, frames) — TODO confirm against caller) or a single
                pre-stacked batch tensor.
            mel_input_lengths: 1-D tensor of valid frame counts, one per item.
            num_threads: maximum worker threads; clamped to the batch size.
            max_new_tokens: decoding budget forwarded to process_single_batch.

        Returns:
            list of decoded texts, in the same order as the input batch.
        """
        batch_size = len(mel)
        if batch_size == 0:
            # Guard: an empty batch would make num_threads 0, causing
            # ZeroDivisionError in the chunk-size math and a ValueError
            # from ThreadPoolExecutor(max_workers=0).
            return []

        # The same decoder prompt prefix is used for every item in the batch.
        prompt_id = self.tokenizer.encode(
            self.text_prefix, allowed_special=self.tokenizer.special_tokens_set)
        prompt_id = torch.tensor(prompt_id)
        decoder_input_ids = prompt_id.repeat(batch_size, 1)

        with torch.no_grad():
            if isinstance(mel, list):
                # Per-item tensors: swap the last two dims, cast to fp16 for the
                # engine, drop the leading singleton dim, then stack into a batch.
                mel = torch.stack(
                    [m.transpose(1, 2).type(torch.float16).squeeze(0) for m in mel])
            else:
                mel = mel.transpose(1, 2)

            num_threads = min(num_threads, batch_size)
            # Ceil division so torch.split yields at most num_threads chunks.
            # The previous floor division could yield up to 2*num_threads - 1
            # chunks (e.g. batch 5, threads 4 -> five chunks of 1).
            chunk_size = -(-batch_size // num_threads)
            mel_batches = torch.split(mel, chunk_size)
            mel_input_lengths_batches = torch.split(mel_input_lengths, chunk_size)

            texts_list = []
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                futures = []
                for i, mel_batch in enumerate(mel_batches):
                    # The last chunk may be smaller; slice the prompt rows to match.
                    current_length = mel_batch.size(0)
                    futures.append(executor.submit(
                        self.process_single_batch,
                        mel_batch,
                        decoder_input_ids[:current_length],
                        mel_input_lengths_batches[i],
                        max_new_tokens,
                    ))

                # Collect in submission order so output order matches input order.
                for future in futures:
                    texts_list.extend(future.result())

        return texts_list