def main()

in nemo_asr/run_eval.py

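# Imports and module-level objects assumed by this excerpt; the full run_eval.py
# defines these at module scope and the exact names may differ.
import io
import os
import time

import numpy as np
import soundfile
import torch
from tqdm import tqdm

import evaluate
from nemo.collections.asr.models import ASRModel

from normalizer import data_utils  # project-local data helpers (assumed)

wer_metric = evaluate.load("wer")  # HF `evaluate` WER metric used at the end of main()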

def main(args):

    DATA_CACHE_DIR = os.path.join(os.getcwd(), "audio_cache")
    DATASET_NAME = args.dataset
    SPLIT_NAME = args.split

    CACHE_DIR = os.path.join(DATA_CACHE_DIR, DATASET_NAME, SPLIT_NAME)
    os.makedirs(CACHE_DIR, exist_ok=True)

    if args.device >= 0:
        device = torch.device(f"cuda:{args.device}")
        compute_dtype = torch.bfloat16
    else:
        device = torch.device("cpu")
        compute_dtype = torch.float32

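    # load the model from a local .nemo checkpoint or by pretrained model name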
    if args.model_id.endswith(".nemo"):
        asr_model = ASRModel.restore_from(args.model_id, map_location=device)
    else:
        asr_model = ASRModel.from_pretrained(args.model_id, map_location=device)  # type: ASRModel
    
    asr_model.to(compute_dtype)
    asr_model.eval()

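    # load the evaluation dataset; data_utils.load_data is a project helper assumed to
    # return a Hugging Face dataset for args.dataset / args.split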
    dataset = data_utils.load_data(args)

    def download_audio_files(batch):

        # materialize the audio for each sample in the batch as a wav file in the local
        # cache and record the local paths, durations and reference transcriptions
        audio_paths = []
        durations = []

        for sample_id, sample in zip(batch["id"], batch["audio"]):

            # flatten the ID first so both the ID and the wav filename are unique:
            # several datasets like earnings22 have a hierarchical structure,
            # e.g. earnings22/test/4432298/281.wav and earnings22/test/4450488/281.wav,
            # and lhotse uses the bare filename (281.wav) as the unique ID to create and name cuts
            # ref: https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/collation.py#L186
            sample_id = sample_id.replace('/', '_').removesuffix('.wav')

            audio_path = os.path.join(CACHE_DIR, f"{sample_id}.wav")

            if "array" in sample:
                audio_array = np.float32(sample["array"])
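                # NOTE: assumes the decoded audio has already been resampled to 16 kHz upstream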
                sample_rate = 16000

            elif "bytes" in sample: # added to be compatible with latest datasets library (3.x.x) that produces byte stream
                with io.BytesIO(sample["bytes"]) as audio_file:
                    audio_array, sample_rate = soundfile.read(audio_file, dtype="float32")

            else:
                raise ValueError("Sample must have either 'array' or 'bytes' key")

            if not os.path.exists(audio_path):
                os.makedirs(os.path.dirname(audio_path), exist_ok=True)
                soundfile.write(audio_path, audio_array, sample_rate)

            audio_paths.append(audio_path)
            durations.append(len(audio_array) / sample_rate)

        
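        # "norm_text" holds the normalized reference transcription (assumed to be added by data_utils.prepare_data)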
        batch["references"] = batch["norm_text"]
        batch["audio_filepaths"] = audio_paths
        batch["durations"] = durations

        return batch


    if args.max_eval_samples is not None and args.max_eval_samples > 0:
        print(f"Subsampling dataset to first {args.max_eval_samples} samples !")
        dataset = dataset.take(args.max_eval_samples)

    dataset = data_utils.prepare_data(dataset)
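    # default to batched greedy decoding for throughput unless beam search was explicitly configured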
    if asr_model.cfg.decoding.strategy != "beam":
        asr_model.cfg.decoding.strategy = "greedy_batch"
        asr_model.change_decoding_strategy(asr_model.cfg.decoding)
    
    # preparing the offline dataset: download and cache the audio files locally
    dataset = dataset.map(download_audio_files, batch_size=args.batch_size, batched=True, remove_columns=["audio"])

    # Gather audio filepaths, durations and references from the dataset; the results manifest is written after transcription

    all_data = {
        "audio_filepaths": [],
        "durations": [],
        "references": [],
    }

    data_itr = iter(dataset)
    for data in tqdm(data_itr, desc="Downloading Samples"):
        for key in all_data:
            all_data[key].append(data[key])
    
    # Sort audio_filepaths, references and durations by duration (longest first) so similar-length utterances are batched together
    sorted_indices = sorted(range(len(all_data["durations"])), key=lambda k: all_data["durations"][k], reverse=True)
    all_data["audio_filepaths"] = [all_data["audio_filepaths"][i] for i in sorted_indices]
    all_data["references"] = [all_data["references"][i] for i in sorted_indices]
    all_data["durations"] = [all_data["durations"][i] for i in sorted_indices]
    
    
    total_time = 0
    for pass_idx in range(2):  # first pass warms up the model, second pass is timed for RTFx
        if pass_idx == 0:
            audio_files = all_data["audio_filepaths"][:args.batch_size * 4]  # warmup with 4 batches
        else:
            audio_files = all_data["audio_filepaths"]
        start_time = time.time()
        with torch.cuda.amp.autocast(enabled=False, dtype=compute_dtype), torch.inference_mode(), torch.no_grad():
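            # Canary is a multitask model; pnc='no' requests output without punctuation and capitalization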
            if 'canary' in args.model_id:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, pnc='no', num_workers=1)
            else:
                transcriptions = asr_model.transcribe(audio_files, batch_size=args.batch_size, verbose=False, num_workers=1)
        end_time = time.time()
        if pass_idx == 1:
            total_time += end_time - start_time

    # some NeMo models return a (best_hypotheses, all_hypotheses) tuple; keep only the best hypotheses
    if isinstance(transcriptions, tuple) and len(transcriptions) == 2:
        transcriptions = transcriptions[0]

    # normalize transcriptions with the English text normalizer
    predictions = [data_utils.normalizer(pred.text) for pred in transcriptions]

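    # approximate per-sample transcription time by spreading the total wall-clock time evenly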
    avg_time = total_time / len(all_data["audio_filepaths"])

    # Write the results manifest (references, predictions, durations, per-sample transcription times)
    manifest_path = data_utils.write_manifest(
        all_data["references"],
        predictions,
        args.model_id,
        args.dataset_path,
        args.dataset,
        args.split,
        audio_length=all_data["durations"],
        transcription_time=[avg_time] * len(all_data["audio_filepaths"]),
    )

    print("Results saved at path:", os.path.abspath(manifest_path))

    wer = wer_metric.compute(references=all_data['references'], predictions=predictions)
    wer = round(100 * wer, 2)

    audio_length = sum(all_data["durations"])
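    # RTFx = total audio duration / transcription wall-clock time (higher means faster than real time)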
    rtfx = audio_length / total_time
    rtfx = round(rtfx, 2)

    print("RTFX:", rtfx)
    print("WER:", wer, "%")