def main()

in kyutai/run_eval.py [0:0]


def main(args):
    """Run the ASR benchmark for ``args.model_id`` and report WER and RTFx.

    Loads the model through ``load_model``, optionally runs warm-up batches,
    transcribes the evaluation dataset batch by batch, writes a results
    manifest via ``data_utils.write_manifest`` and prints the WER and RTFx.

    Args:
        args: Parsed options. This function reads ``model_id``, ``device``,
            ``warmup_steps``, ``batch_size``, ``streaming``,
            ``max_eval_samples``, ``dataset_path``, ``dataset`` and ``split``;
            the whole object is also handed to ``data_utils.load_data``.
    """
    (
        mimi,
        tokenizer,
        _lm,
        lm_gen,
        padding_token_id,
        audio_silence_prefix_seconds,
        audio_delay_seconds,
    ) = load_model(args.model_id)

    mimi_frame_size = mimi.frame_size

    def benchmark(batch):
        """Transcribe one batch and attach timing, predictions and references.

        Adds ``audio_length_s``, ``transcription_time_s``, ``predictions`` and
        ``references`` to ``batch`` and returns it.
        """
        # Load audio inputs
        clips = batch["audio"]
        audios = [torch.from_numpy(clip["array"]) for clip in clips]
        sample_rates = [clip["sampling_rate"] for clip in clips]

        # Each clip's duration is computed with its own sampling rate, so a
        # batch mixing rates still yields correct lengths (and hence RTFx).
        batch["audio_length_s"] = [
            len(audio) / rate for audio, rate in zip(audios, sample_rates)
        ]
        minibatch_size = len(audios)

        # Start timing: the measured span covers padding, encoding,
        # generation and token decoding.
        start_time = time.time()

        padded_batch = get_padded_batch(
            audios,
            sample_rates,
            before_padding=audio_silence_prefix_seconds,
            after_padding=audio_delay_seconds,
            frame_size=mimi_frame_size,
        )
        padded_batch = padded_batch.to(args.device).float()

        bsz = padded_batch.shape[0]

        text_tokens_acc = []

        with mimi.streaming(bsz), lm_gen.streaming(bsz):
            # Feed the padded audio one frame-sized chunk at a time.
            for offset in range(0, padded_batch.shape[-1], mimi_frame_size):
                # NOTE(review): the chunk is moved with .cuda() although the
                # batch was already sent to args.device; presumably the model
                # sits on the default CUDA device -- confirm before changing.
                audio_chunk = padded_batch[:, offset : offset + mimi_frame_size].cuda()
                tokens = mimi.encode(audio_chunk[:, None, :])
                text_tokens = lm_gen.step(tokens)
                text_tokens_acc.append(text_tokens)

        # Join the per-step outputs along the last axis, then split the
        # result into one token sequence per batch item.
        pred_tokens = torch.concat(text_tokens_acc, axis=-1).squeeze(dim=1)
        pred_tokens = torch.unbind(pred_tokens, dim=0)

        # Only token ids strictly greater than padding_token_id are decoded.
        pred_text = [
            tokenizer.decode(t[t > padding_token_id].cpu().numpy().tolist())
            for t in pred_tokens
        ]

        # End timing
        runtime = time.time() - start_time

        # normalize by minibatch size since we want the per-sample time
        batch["transcription_time_s"] = minibatch_size * [runtime / minibatch_size]

        # normalize transcriptions with English normalizer
        batch["predictions"] = [data_utils.normalizer(pred) for pred in pred_text]
        batch["references"] = batch["norm_text"]
        return batch

    if args.warmup_steps is not None:
        # Warm-up pass: run `benchmark` over the first few batches and
        # discard the results.
        warmup_dataset = data_utils.load_data(args)
        warmup_dataset = data_utils.prepare_data(warmup_dataset)

        num_warmup_samples = args.warmup_steps * args.batch_size
        if args.streaming:
            warmup_dataset = warmup_dataset.take(num_warmup_samples)
        else:
            warmup_dataset = warmup_dataset.select(
                range(min(num_warmup_samples, len(warmup_dataset)))
            )
        warmup_dataset = iter(
            warmup_dataset.map(benchmark, batch_size=args.batch_size, batched=True)
        )

        # Iterating is what drives the mapped warm-up batches.
        for _ in tqdm(warmup_dataset, desc="Warming up..."):
            continue

    dataset = data_utils.load_data(args)
    dataset = data_utils.prepare_data(dataset)

    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples!")
        if args.streaming:
            dataset = dataset.take(args.max_eval_samples)
        else:
            dataset = dataset.select(range(min(args.max_eval_samples, len(dataset))))

    dataset = dataset.map(
        benchmark,
        batch_size=args.batch_size,
        batched=True,
        remove_columns=["audio"],
    )

    # Collect the per-sample fields produced by `benchmark`.
    all_results = {
        "audio_length_s": [],
        "transcription_time_s": [],
        "predictions": [],
        "references": [],
    }
    result_iter = iter(dataset)
    for result in tqdm(result_iter, desc="Samples..."):
        for key in all_results:
            all_results[key].append(result[key])

    # Write manifest results (WER and RTFX)
    manifest_path = data_utils.write_manifest(
        all_results["references"],
        all_results["predictions"],
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_results["audio_length_s"],
        transcription_time=all_results["transcription_time_s"],
    )
    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(
        references=all_results["references"], predictions=all_results["predictions"]
    )
    wer = round(100 * wer, 2)
    # RTFx: total seconds of audio divided by total seconds spent transcribing.
    rtfx = round(
        sum(all_results["audio_length_s"]) / sum(all_results["transcription_time_s"]), 2
    )
    print("WER:", wer, "%", "RTFx:", rtfx)