def main()

in tensorrtllm/run_eval.py [0:0]


def main(args):
    """Run a batched TensorRT-LLM Whisper evaluation and report WER / RTFx.

    Builds a ``WhisperTRTLLM`` model from ``args.model_id``, optionally runs
    ``args.warmup_steps`` warm-up batches, transcribes the evaluation data,
    writes a results manifest via ``data_utils.write_manifest`` and prints
    the word error rate (percent) and RTFx (total audio seconds divided by
    total transcription seconds).

    Args:
        args: Parsed CLI namespace. Attributes read directly here:
            ``model_id``, ``warmup_steps``, ``batch_size``, ``streaming``,
            ``max_new_tokens``, ``max_eval_samples``, ``dataset_path``,
            ``dataset`` and ``split`` (``data_utils.load_data`` may read
            others).
    """
    asr_model = WhisperTRTLLM(engine_dir=args.model_id)

    def merge_chunk_texts(text_chunks):
        """Stitch the per-chunk transcripts of one clip into a single string.

        Consecutive chunks come from overlapping audio windows, so their
        transcripts share text. For each new chunk, the longest common
        substring with the text merged so far is located, and only what
        follows it in the new chunk is appended.
        """
        merged_text = text_chunks[0]
        for t in text_chunks[1:]:
            lcs = longest_common_substring(merged_text, t)
            # The shared text is not necessarily a prefix of ``t``; skip past
            # wherever it actually occurs. find() returns 0 when lcs is a
            # prefix (or empty), reproducing the plain ``t[len(lcs):]`` slice.
            start = t.find(lcs)
            if start < 0:
                # lcs not found in t: fall back to the prefix assumption.
                start = 0
            merged_text += t[start + len(lcs):]
        return merged_text

    def benchmark(batch, min_new_tokens=None):
        """Transcribe one batch and attach predictions, references and timing.

        Args:
            batch: Batched dataset dict; reads ``"audio"`` (items with an
                ``"array"`` entry) and ``"norm_text"``.
            min_new_tokens: Accepted so it can be passed via ``fn_kwargs``;
                currently unused.

        Returns:
            ``batch`` with ``"transcription_time_s"`` (per-sample seconds),
            ``"predictions"`` (normalized) and ``"references"`` added.
        """
        # Load audio inputs
        max_duration, sample_rate = 30, 16000
        audios_origin = [audio["array"].astype(np.float32) for audio in batch["audio"]]
        minibatch_size = len(audios_origin)
        audios, audio_index = [], []

        # Clips longer than max_duration are split by chunk_audio (presumably
        # 25 s windows overlapping by 5 s). audio_index maps every transcribed
        # piece back to the clip it came from.
        chunk_length = 25
        overlap_length = 5
        for i, audio in enumerate(audios_origin):
            if len(audio) > max_duration * sample_rate:
                audio_chunks = chunk_audio(audio, chunk_length, overlap_length, sample_rate)
            else:
                audio_chunks = [audio]
            for chunk in audio_chunks:
                audios.append(chunk)
                audio_index.append(i)
        audios = [torch.from_numpy(audio) for audio in audios]

        # START TIMING
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start_time = time.perf_counter()
        longest_duration = int(sample_rate * max_duration)

        # Every waveform is padded up to max_duration seconds of samples.
        features = [
            log_mel_spectrogram(wave,
                                asr_model.n_mels,
                                padding=longest_duration - wave.shape[-1],
                                device='cuda').unsqueeze(0)
            for wave in audios
        ]

        # f.shape[2] is presumably the frame axis of the unsqueezed features.
        features_input_lengths = torch.tensor([f.shape[2] for f in features],
                                              dtype=torch.int32,
                                              device='cuda')

        texts_origin = asr_model.process_batch(features, features_input_lengths, num_threads=4)
        if len(texts_origin) != len(audio_index):
            raise RuntimeError(
                f"Expected {len(audio_index)} transcripts from process_batch, "
                f"got {len(texts_origin)}"
            )

        # Group transcripts by source clip in a single pass. audio_index is
        # non-decreasing, so each clip's chunks stay in temporal order.
        grouped = [[] for _ in range(minibatch_size)]
        for clip_idx, text in zip(audio_index, texts_origin):
            grouped[clip_idx].append(text)
        texts = [merge_chunk_texts(chunks) for chunks in grouped]
        # END TIMING
        runtime = time.perf_counter() - start_time

        print(f"Batch size: {minibatch_size}, Time taken: {runtime:.2f} s, texts_origin_len: {len(texts_origin)}, texts_len: {len(texts)}")
        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in texts]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        dataset = data_utils.load_data(args)
        dataset = data_utils.prepare_data(dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = dataset.take(num_warmup_samples)
        else:
            warmup_dataset = dataset.select(range(min(num_warmup_samples, len(dataset))))
        warmup_dataset = iter(warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True, fn_kwargs={"min_new_tokens": args.max_new_tokens}))

        # Results are discarded; iterating drives the lazy map so the
        # warm-up batches actually execute.
        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))
    dataset = data_utils.prepare_data(dataset)

    dataset = dataset.map(
        benchmark, batch_size=args.batch_size, batched=True, remove_columns=["audio"],
    )

    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    # RTFx: seconds of audio processed per second of transcription time.
    rtfx = round(sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2)
    print("WER:", wer, "%", "RTFx:", rtfx)