# def main()
#
# in training/flax/run_long_form_transcription.py [0:0]


def main():
    """Run long-form transcription evaluation with a Flax Whisper pipeline.

    Parses `ModelArguments`, `DataTrainingArguments` and `Seq2SeqTrainingArguments`
    (from a single `.json` file argument or from the command line), optionally sets up
    TensorBoard / Weights & Biases reporting on process 0, loads one or more evaluation
    datasets, instantiates a `FlaxWhisperPipeline`, and for every evaluation split
    transcribes each sample and logs WER-style metrics.

    Raises:
        ValueError: if a configured text column or the audio column is missing from the
            loaded dataset(s), or if the model config has no `decoder_start_token_id`.
    """
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your JAX/Flax versions.
    send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")

    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if "tensorboard" in training_args.report_to:
        if not has_tensorboard:
            # Only warn about a missing package when it is actually missing; non-master
            # processes with tensorboard installed simply skip the writer setup.
            logger.warning(
                "Unable to display metrics through TensorBoard because the package is"
                " not installed: Please run `pip install tensorboard` to enable."
            )
        elif jax.process_index() == 0:
            try:
                from flax.metrics.tensorboard import SummaryWriter

                summary_writer = SummaryWriter(log_dir=Path(os.path.join(training_args.output_dir, "runs")))
            except ImportError as ie:
                # Disable tensorboard reporting so `summary_writer` is never referenced later.
                has_tensorboard = False
                logger.warning(
                    f"Unable to display metrics through TensorBoard because some packages are not installed: {ie}"
                )

    # Enable wandb only on the master node
    has_wandb = is_wandb_available()
    if "wandb" in training_args.report_to:
        if not has_wandb:
            logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")
        elif jax.process_index() == 0:
            import wandb as wandb_logger

            # Set up wandb run
            wandb_logger.init(
                project=data_args.wandb_project,
                name=data_args.wandb_name,
                job_type=data_args.wandb_job_type,
                dir=data_args.wandb_dir,
                save_code=data_args.save_code_to_wandb,
            )

    # 2. Setup logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Set the verbosity to info of the Transformers logger.
    # We only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info("Evaluation parameters %s", training_args)

    if model_args.compilation_cache:
        cc.initialize_cache(os.path.join(model_args.cache_dir, "jax_cache"))

    # 3. Load dataset
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()

    # Convert lists of dataset names/configs/splits to a list of dicts
    # names: "librispeech_asr+gigaspeech", configs: "all+l", splits: "validation.clean+validation"
    # -> [{"name": "librispeech_asr", "config": "all", "split": "validation.clean"},
    #     {"name": "gigaspeech", "config": "l", "split": "validation"}]
    dataset_names_dict = convert_dataset_str_to_list(
        data_args.dataset_name,
        data_args.dataset_config_name,
        splits=data_args.dataset_split_name,
        text_column_names=data_args.text_column_name,
    )

    # load multiple eval sets
    for dataset_dict in dataset_names_dict:
        # Clean-up the dataset name for pretty logging
        # ("distil-whisper/librispeech_asr", "validation.clean") -> "librispeech_asr/validation-clean"
        pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
        raw_datasets[pretty_name] = load_dataset(
            dataset_dict["name"],
            dataset_dict["config"],
            split=dataset_dict["split"],
            cache_dir=data_args.dataset_cache_dir,
            use_auth_token=True if model_args.use_auth_token else None,
            streaming=data_args.streaming,
        )
        if dataset_dict["text_column_name"] not in list(raw_datasets[pretty_name].features.keys()):
            raise ValueError(
                f"--text column name {dataset_dict['text_column_name']} not found in the evaluation "
                f"dataset {dataset_dict['name']}. Ensure `text_column_name` is set to the correct column "
                f"for the target text. Should be one of {' '.join(list(raw_datasets[pretty_name].features.keys()))}"
            )
        # Normalise the target column name so the rest of the script can always read "text".
        if dataset_dict["text_column_name"] != "text":
            raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
                dataset_dict["text_column_name"], "text"
            )

    # Streaming mode robust way of obtaining the features
    # (only the first loaded dataset is inspected here)
    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
    audio_column_name = data_args.audio_column_name

    if audio_column_name not in raw_datasets_features:
        raise ValueError(
            f"--audio_column_name '{audio_column_name}' not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
            " the correct audio column - one of"
            f" {', '.join(raw_datasets_features)}."
        )

    # Keep only the audio and text columns of every split.
    for split in raw_datasets:
        raw_datasets[split] = raw_datasets[split].remove_columns(
            set(raw_datasets[split].features.keys()) - {audio_column_name, "text"}
        )

    # 4. Load pretrained model, tokenizer, and feature extractor
    pipeline = FlaxWhisperPipeline(
        model_args.model_name_or_path,
        dtype=getattr(jnp, model_args.dtype),
        max_length=training_args.generation_max_length,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        subfolder=model_args.subfolder,
        # use_scan=model_args.load_with_scan,  # Model might have (erroneously) been saved with scan still enabled
    )

    if pipeline.model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    # disable scan if necessary (makes the inference step faster)
    if model_args.load_with_scan:
        pipeline.model.disable_scan()  # to disable scan in the nn.Module
        pipeline.params = pipeline.model.convert_scan_to_unroll(
            pipeline.params
        )  # to convert the scan params to unrolled

    # 5. Possibly evaluate on a subset of data
    if data_args.max_eval_samples is not None:
        for split in raw_datasets:
            if data_args.streaming:
                raw_datasets[split] = raw_datasets[split].take(data_args.max_eval_samples)
            else:
                # Clamp to the dataset size so asking for more samples than exist keeps the full split
                # rather than failing on out-of-range indices.
                num_samples = min(len(raw_datasets[split]), data_args.max_eval_samples)
                raw_datasets[split] = raw_datasets[split].select(range(num_samples))

    # 6. Compute WER Metrics
    normalizer = EnglishTextNormalizer(pipeline.tokenizer.english_spelling_normalizer)

    def compute_metrics(pred_str, label_str, ngram_degree=5):
        """Normalise predictions/labels and compute WER-style metrics.

        Returns a tuple `(metrics, pred_str, label_str, norm_pred_str, norm_label_str)` where
        every list has had the samples with an empty normalised reference removed, and
        `metrics` holds "wer", "ier", "ser", "der" (scaled by 100) and "repeated_ngrams".
        """
        # normalize everything and compute the WER
        norm_pred_str = [normalizer(pred).replace(".", "") for pred in pred_str]
        norm_label_str = [normalizer(label) for label in label_str]

        # filtering step to only evaluate the samples that correspond to non-zero normalized references;
        # for logging, the raw pred/labels are filtered with the same indices so all four lists stay aligned
        keep = [i for i, norm_label in enumerate(norm_label_str) if len(norm_label) > 0]
        pred_str = [pred_str[i] for i in keep]
        label_str = [label_str[i] for i in keep]
        norm_pred_str = [norm_pred_str[i] for i in keep]
        norm_label_str = [norm_label_str[i] for i in keep]

        wer_output = process_words(norm_label_str, norm_pred_str, wer_default, wer_default)
        # Shared denominator for the insertion / substitution / deletion rates.
        reference_length = sum(len(ref) for ref in wer_output.references)
        wer_norm = 100 * wer_output.wer
        ier_norm = 100 * wer_output.insertions / reference_length
        ser_norm = 100 * wer_output.substitutions / reference_length
        der_norm = 100 * wer_output.deletions / reference_length

        # Count n-grams that occur more than once across the concatenated predictions.
        all_ngrams = list(ngrams(" ".join(norm_pred_str).split(), ngram_degree))
        repeated_ngrams = len(all_ngrams) - len(set(all_ngrams))

        return (
            {"wer": wer_norm, "ier": ier_norm, "ser": ser_norm, "der": der_norm, "repeated_ngrams": repeated_ngrams},
            pred_str,
            label_str,
            norm_pred_str,
            norm_label_str,
        )

    # Store some constants
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()
    num_beams = (
        training_args.generation_num_beams
        if training_args.generation_num_beams is not None
        else pipeline.model.config.num_beams
    )

    generation_config = pipeline.model.generation_config
    if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
        # We need to set the language and task ids for previously multilingual checkpoints - for now we hardcode this to English
        language = "English"
        task = "transcribe"
    else:
        language = None
        task = None

    # pre-compile the model so that we don't count it in our eval
    logger.info("Pre-compiling the generate call...")
    random_inputs = {"input_features": np.ones((eval_batch_size, 80, 2 * pipeline.model.config.max_source_positions))}
    pipeline.forward(
        random_inputs,
        batch_size=eval_batch_size,
        language=language,
        task=task,
        return_timestamps=model_args.return_timestamps,
        num_beams=num_beams,
        length_penalty=model_args.length_penalty,
        do_sample=model_args.do_sample,
        top_k=model_args.top_k,
        temperature=model_args.temperature,
    )

    def eval_step(split="eval"):
        """Transcribe every sample of `raw_datasets[split]` and report the resulting metrics."""
        # ======================== Evaluating ==============================
        eval_preds = []
        eval_labels = []
        eval_audios = []
        eval_start = time.time()

        for sample in tqdm(raw_datasets[split], desc=f"Evaluating {split}..."):
            # Model forward
            label_str = sample["text"]
            # Use the configured audio column: every column other than `audio_column_name` and
            # "text" was removed above, so a hard-coded "audio" key would fail for other names.
            audio = sample[audio_column_name]
            if data_args.log_audio:
                eval_audios.append(audio)

            pred_str = pipeline(
                audio,
                batch_size=eval_batch_size,
                language=language,
                task=task,
                chunk_length_s=model_args.chunk_length_s,
                return_timestamps=model_args.return_timestamps,
                num_beams=num_beams,
                length_penalty=model_args.length_penalty,
                do_sample=model_args.do_sample,
                top_k=model_args.top_k,
                temperature=model_args.temperature,
            )
            eval_preds.append(pred_str["text"])
            eval_labels.append(label_str)

        eval_time = time.time() - eval_start

        wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
            eval_preds, eval_labels, ngram_degree=data_args.ngram_degree
        )
        wer_desc = " ".join([f"{split} {key}: {value} |" for key, value in wer_metric.items()])

        # Print metrics to stdout
        logger.info(wer_desc)

        # Save metrics to tensorboard
        if has_tensorboard and jax.process_index() == 0 and "tensorboard" in training_args.report_to:
            write_metric(summary_writer, wer_metric, prefix=split)

        # Save metrics to wandb
        if has_wandb and jax.process_index() == 0 and "wandb" in training_args.report_to:
            write_wandb_metric(wandb_logger, wer_metric, eval_time, prefix=split)
            if data_args.log_predictions:
                # NOTE(review): `eval_audios` is not filtered like the prediction/label lists returned by
                # `compute_metrics`, so it may be misaligned with them when samples are dropped - confirm
                # how `write_wandb_pred` pairs audios with predictions.
                write_wandb_pred(
                    wandb_logger, eval_audios, pred_str, label_str, norm_pred_str, norm_label_str, prefix=split
                )

    logger.info("***** Running Eval *****")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_eval_batch_size}")
    logger.info(f"  Total eval batch size (w. parallel & distributed) = {eval_batch_size}")
    logger.info(f"  Beam size = {num_beams}")
    if num_beams > 1:
        logger.info(f"  Length penalty size = {model_args.length_penalty}")
    logger.info(f"  Do sample = {model_args.do_sample}")
    if model_args.do_sample:
        logger.info(f"  Top k = {model_args.top_k}")
        logger.info(f"  Temperature = {model_args.temperature}")

    for split in raw_datasets:
        eval_step(split=split)