def main()

in training/flax/run_pt_long_form_transcription.py [0:0]


def main():
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
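    # Illustrative invocations (a sketch only; flag names follow the dataclass fields used below):
    #   python run_pt_long_form_transcription.py eval_config.json
    #   python run_pt_long_form_transcription.py --model_name_or_path <model> --dataset_name <dataset> ...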

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if "tensorboard" in training_args.report_to:
        if has_tensorboard:
            try:
                from torch.utils.tensorboard import SummaryWriter

                summary_writer = SummaryWriter(log_dir=os.path.join(training_args.output_dir, "runs"))
            except ImportError as ie:
                has_tensorboard = False
                logger.warning(
                    "Unable to display metrics through TensorBoard because some"
                    f" packages are not installed: {ie}"
                )
        else:
            logger.warning(
                "Unable to display metrics through TensorBoard because the package is"
                " not installed: Please run `pip install tensorboard` to enable."
            )

    # Enable wandb only on the master node
    has_wandb = is_wandb_available()
    if "wandb" in training_args.report_to:
        if has_wandb:
            import wandb as wandb_logger

            # Set up wandb run
            wandb_logger.init(
                project=data_args.wandb_project,
                name=data_args.wandb_name,
                job_type=data_args.wandb_job_type,
                dir=data_args.wandb_dir,
                save_code=data_args.save_code_to_wandb,
            )
        else:
            logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")

    # 2. Set up logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Set the verbosity of the Transformers logger to INFO.
    # We only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO)
    datasets.utils.logging.set_verbosity_warning()
    transformers.utils.logging.set_verbosity_info()
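    # Keep the `datasets` library quieter (warnings only) while `transformers` logs at INFO.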

    logger.info("Evaluation parameters %s", training_args)

    # 3. Load dataset
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()

    # Convert lists of dataset names/configs/splits to a dict
    # names: "librispeech_asr+gigaspeech", configs: "all+l", splits: "validation.clean+validation"
    # -> [{"name": "librispeech_asr", "config": "all", "split": "validation.clean"},
    #     {"name": "gigaspeech", "config": "l", "split": "validation"}]
    dataset_names_dict = convert_dataset_str_to_list(
        data_args.dataset_name,
        data_args.dataset_config_name,
        splits=data_args.dataset_split_name,
        text_column_names=data_args.text_column_name,
    )

    # Load each evaluation set
    for dataset_dict in dataset_names_dict:
        # Clean-up the dataset name for pretty logging
        # ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
        pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
        raw_datasets[pretty_name] = load_dataset(
            dataset_dict["name"],
            dataset_dict["config"],
            split=dataset_dict["split"],
            cache_dir=data_args.dataset_cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
            streaming=data_args.streaming,
        )
        if dataset_dict["text_column_name"] not in list(raw_datasets[pretty_name].features.keys()):
            raise ValueError(
                f"--text column name {dataset_dict['text_column_name']} not found in the evaluation "
                f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
                f"for the target text. Should be one of {' '.join(list(raw_datasets[pretty_name].features.keys()))}"
            )
        if dataset_dict["text_column_name"] != "text":
            raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
                dataset_dict["text_column_name"], "text"
            )

    # Obtain the dataset features in a way that is robust to streaming mode
    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
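    # Only the first split is inspected; every split is assumed to expose the same audio/text columns.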
    audio_column_name = data_args.audio_column_name

    if audio_column_name not in raw_datasets_features:
        raise ValueError(
            f"--audio_column_name '{audio_column_name}' not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
            " the correct audio column - one of"
            f" {', '.join(raw_datasets_features)}."
        )

    for split in raw_datasets:
        raw_datasets[split] = raw_datasets[split].remove_columns(
            set(raw_datasets[split].features.keys()) - {audio_column_name, "text"}
        )

    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            raw_datasets[split] = (
                raw_datasets[split].take(data_args.max_eval_samples)
                if data_args.streaming
                else raw_datasets[split].select(range(data_args.max_eval_samples))
            )
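    # (`take` is the streaming/IterableDataset counterpart of `select`, which only exists on map-style datasets.)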

    # Store some constants
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    num_beams = training_args.generation_num_beams if training_args.generation_num_beams is not None else 1

    model_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_auth_token": True if model_args.use_auth_token else None,
        "subfolder": model_args.subfolder,
    }
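    # These kwargs are forwarded to `from_pretrained` when the pipeline instantiates the model below.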

    # 4. Load the pretrained model, tokenizer, and feature extractor via the ASR pipeline
    pipe = pipeline(
        "automatic-speech-recognition",
        model_args.model_name_or_path,
        torch_dtype=getattr(torch, model_args.dtype),
        model_kwargs=model_kwargs,
        max_new_tokens=training_args.generation_max_length,
        batch_size=per_device_eval_batch_size,
        chunk_length_s=model_args.chunk_length_s,
        return_timestamps=model_args.return_timestamps,
        device="cuda:0" if torch.cuda.is_available() else "cpu",
    )
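    # With `chunk_length_s` set, the pipeline runs chunked long-form transcription: the audio is split into
    # overlapping windows, transcribed in batches, and the chunk outputs are merged back together.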

    if pipe.model.can_generate():
        if pipe.model.config.decoder_start_token_id is None:
            raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
        generate_kwargs = {
            "num_beams": num_beams,
            "length_penalty": model_args.length_penalty,
            "do_sample": model_args.do_sample,
            "top_k": model_args.top_k,
            "temperature": model_args.temperature,
        }
        if hasattr(pipe.model.generation_config, "is_multilingual") and pipe.model.generation_config.is_multilingual:
            generate_kwargs.update({"language": "English", "task": "transcribe"})
    else:
        generate_kwargs = None
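        # Models that cannot generate (e.g. CTC encoders) take no generation kwargs.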

    # 5. Load the metric components (text normalizer for WER computation)
    whisper_tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
    normalizer = EnglishTextNormalizer(whisper_tokenizer.english_spelling_normalizer)
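    # The Whisper English normalizer strips casing/punctuation and standardises spellings, so the error rates
    # reflect transcription quality rather than formatting differences.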

    def compute_metrics(pred_str, label_str, ngram_degree=5):
        # normalize everything and re-compute the WER
        norm_pred_str = [normalizer(pred) for pred in pred_str]
        norm_label_str = [normalizer(label) for label in label_str]
        # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
        pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
        label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
        # filtering step to only evaluate the samples that correspond to non-zero normalized references:
        norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
        norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]

        wer_output = process_words(norm_label_str, norm_pred_str, wer_default, wer_default)
        wer_norm = 100 * wer_output.wer
        ier_norm = 100 * wer_output.insertions / sum([len(ref) for ref in wer_output.references])
        ser_norm = 100 * wer_output.substitutions / sum([len(ref) for ref in wer_output.references])
        der_norm = 100 * wer_output.deletions / sum([len(ref) for ref in wer_output.references])
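        # Worked example (hypothetical): ref "the cat sat", pred "the cat sat sat" -> 1 insertion over
        # 3 reference words, so WER = IER = 100 * 1 / 3 ≈ 33.3 and SER = DER = 0.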

        all_ngrams = list(ngrams(" ".join(norm_pred_str).split(), ngram_degree))
        repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))
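        # A high repeated n-gram count is a rough proxy for looping/hallucination in long-form generations.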

        return (
            {"wer": wer_norm, "ier": ier_norm, "ser": ser_norm, "der": der_norm, "repeated_ngrams": repeated_ngrams},
            pred_str,
            label_str,
            norm_pred_str,
            norm_label_str,
        )

    def eval_step(split="eval"):
        # ======================== Evaluating ==============================
        eval_preds = []
        eval_labels = []
        eval_audios = []
        eval_start = time.time()

        for sample in tqdm(
            pipe(
                data(raw_datasets[split], log_audio=data_args.log_audio),
                generate_kwargs=generate_kwargs,
            ),
            desc=f"Evaluating {split}...",
        ):
            eval_preds.append(sample["text"])
            eval_labels.append(sample["reference"][0])
            if data_args.log_audio:
                eval_audios.append(sample["audio"][0])

        eval_time = time.time() - eval_start

        wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
            eval_preds, eval_labels, data_args.ngram_degree
        )
        wer_desc = " ".join([f"{split} {key}: {value} |" for key, value in wer_metric.items()])

        # Print metrics to stdout
        logger.info(wer_desc)

        # Save metrics to tensorboard
        if has_tensorboard and "tensorboard" in training_args.report_to:
            write_metric(summary_writer, wer_metric, prefix=split)

        # Save metrics to wandb
        if has_wandb and "wandb" in training_args.report_to:
            write_wandb_metric(wandb_logger, wer_metric, eval_time, prefix=split)
            if data_args.log_predictions:
                write_wandb_pred(
                    wandb_logger, eval_audios, pred_str, label_str, norm_pred_str, norm_label_str, prefix=split
                )

    logger.info("***** Running Eval *****")
    logger.info("  Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
    logger.info(f"  Total eval batch size (w. parallel & distributed) = {training_args.per_device_eval_batch_size}")
    if pipe.model.can_generate():
        logger.info(f"  Beam size = {num_beams}")
        if num_beams > 1:
            logger.info(f"  Length penalty size = {model_args.length_penalty}")
        logger.info(f"  Do sample = {model_args.do_sample}")
        if model_args.do_sample:
            logger.info(f"  Top k = {model_args.top_k}")
            logger.info(f"  Temperature = {model_args.temperature}")

    for split in raw_datasets:
        eval_step(split=split)