def main()

in training/run_pseudo_labelling.py


def main():
    # 1. Parse input arguments
    # We keep distinct sets of args for a cleaner separation of model/data/training-related args
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()
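    # Example invocation (illustrative only; flag names follow the dataclass fields referenced in this script):
    #   python run_pseudo_labelling.py --model_name_or_path openai/whisper-large-v3 \
    #       --dataset_name mozilla-foundation/common_voice_16_1 --output_dir ./transcriptions
    # or, equivalently, pass a single JSON config: python run_pseudo_labelling.py args.json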

    # 2. Initialize the accelerator
    # We will let the accelerator handle device placement for us in this example
    # We simply have to specify the training precision and any trackers being used
    # We'll use the same dtype arguments as our JAX/Flax training script and convert
    # them to the accelerate format
    if model_args.dtype == "float16":
        mixed_precision = "fp16"
        torch_dtype = torch.float16
    elif model_args.dtype == "bfloat16":
        mixed_precision = "bf16"
        torch_dtype = torch.bfloat16
    else:
        mixed_precision = "no"
        torch_dtype = torch.float32

    kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=7200))
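    # The generous 2-hour process-group timeout above gives slow ranks headroom during long
    # preprocessing and generation steps, instead of the (shorter) default distributed timeout.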

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=training_args.report_to,
        project_dir=training_args.output_dir,
        kwargs_handlers=[kwargs],
    )

    accelerator.init_trackers(project_name=data_args.wandb_project)

    # 3. Set-up basic logging
    # Create one log on every process with the configuration for debugging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    # Log a small summary on each process
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )

    # Set the verbosity of the Transformers logger to INFO (on the main process only)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
    logger.info("Training/evaluation parameters %s", training_args)

    # 4. Load the dataset
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
    token = model_args.token if model_args.token is not None else HfFolder().get_token()

    data_splits = data_args.dataset_split_name.split("+")
    for split in data_splits:
        with accelerator.main_process_first():
            raw_datasets[split] = load_dataset(
                data_args.dataset_name,
                data_args.dataset_config_name,
                split=split,
                cache_dir=data_args.dataset_cache_dir,
                token=token,
                streaming=data_args.streaming,
                num_proc=data_args.preprocessing_num_workers if not data_args.streaming else None,
            )

    if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
        raise ValueError(
            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
            " the correct audio column - one of"
            f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
        )

    if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
        raise ValueError(
            f"--text_column_name {data_args.text_column_name} not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--text_column_name` to the"
            " correct text column - one of"
            f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
        )
    
    # 5. Load the pretrained model, tokenizer, and feature extractor
    config = WhisperConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )
    feature_extractor = WhisperFeatureExtractor.from_pretrained(
        (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )
    tokenizer = WhisperTokenizerFast.from_pretrained(
        (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=token,
    )
    processor = WhisperProcessor.from_pretrained(
        (model_args.processor_name if model_args.processor_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )

    model = WhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        subfolder=model_args.subfolder,
        token=token,
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        attn_implementation=model_args.attn_implementation,
    )
    model.eval()

    if model.config.decoder_start_token_id is None:
        raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")

    return_timestamps = data_args.return_timestamps
    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
        is_multilingual = True
        # We need to set the language and task ids for multilingual checkpoints
        tokenizer.set_prefix_tokens(
            language=data_args.language, task=data_args.task, predict_timestamps=return_timestamps
        )
    elif data_args.language is not None:
        raise ValueError(
            "Setting language token for an English-only checkpoint is not permitted. The language argument should "
            "only be set for multilingual checkpoints."
        )
    else:
        is_multilingual = False

    # 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
    # so we just need to set the correct target sampling rate.
    raw_datasets = raw_datasets.cast_column(
        data_args.audio_column_name,
        datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
    )

    # 7. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    max_label_length = (
        data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
    )
    audio_column_name = data_args.audio_column_name
    sampling_rate = feature_extractor.sampling_rate

    preprocessing_batch_size = data_args.preprocessing_batch_size
    num_workers = data_args.preprocessing_num_workers
    dataloader_num_workers = training_args.dataloader_num_workers

    text_column_name = data_args.text_column_name
    model_input_name = feature_extractor.model_input_names[0]
    id_column_name = data_args.id_column_name
    speaker_id_column_name = data_args.speaker_id_column_name
    normalizer = (
        BasicTextNormalizer()
        if data_args.language is not None
        else EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
    )

    timestamp_position = 3 if is_multilingual else 1
    decoder_prev_token_id = tokenizer.convert_tokens_to_ids("<|startofprev|>")
    decoder_eot_token_id = tokenizer.eos_token_id
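    # Multilingual checkpoints prefix generations with <|startoftranscript|><|language|><|task|>,
    # so transcript tokens (and any timestamps) start at index 3; English-only checkpoints emit
    # only <|startoftranscript|>, hence index 1. <|startofprev|> marks previous-context (prompt)
    # tokens, and the EOS token terminates each prediction.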

    if data_args.max_samples_per_split is not None:
        for split in data_splits:
            raw_datasets[split] = (
                raw_datasets[split].take(data_args.max_samples_per_split)
                if data_args.streaming
                else raw_datasets[split].select(range(data_args.max_samples_per_split))
            )

    if speaker_id_column_name is not None:
        raw_datasets = raw_datasets.sort(speaker_id_column_name)
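        # sorting by speaker keeps same-speaker utterances adjacent so they can be concatenated below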

    def concatenate_dataset(batch):
        audio_arrays, texts, speaker_ids = [], [], []

        # skip corrupted samples
        for row in table_iter(batch.pa_table, batch_size=1):
            row = batch.formatter.format_row(row)
            try:
                sample_audio = row[audio_column_name]['array']
                sample_text = row[text_column_name]
                sample_speaker_id = row[speaker_id_column_name] if speaker_id_column_name else None
            except LibsndfileError:
                logger.warning(f"{row[id_column_name]} is corrupted! Skipping sample.")
                continue
            audio_arrays.append(sample_audio)
            texts.append(sample_text)
            speaker_ids.append(sample_speaker_id)

        # initialize concatenations
        concat_audio = [audio_arrays[0]]
        concat_text = [texts[0]]
        concat_speaker_id = [speaker_ids[0]]
        condition_on_prev = [0]

        for audio_array, text, speaker_id in zip(audio_arrays[1:], texts[1:], speaker_ids[1:]):
            is_same_speaker = speaker_id == concat_speaker_id[-1]
            is_concatenable = len(audio_array) + len(concat_audio[-1]) <= max_input_length 
            if is_same_speaker and is_concatenable:
                # inplace concatenation
                concat_audio[-1] = np.append(concat_audio[-1], audio_array)
                concat_text[-1] = concat_text[-1] + " " + text
            else:
                concat_audio.append(audio_array)
                concat_text.append(text)
                concat_speaker_id.append(speaker_id)
                condition_on_prev.append(1 if is_same_speaker else 0)   

        batch[audio_column_name] = [{"array": array, "sampling_rate": sampling_rate} for array in concat_audio]
        batch[text_column_name] = concat_text
        batch[id_column_name] = concat_speaker_id
        batch["condition_on_prev"] = condition_on_prev

        return batch
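    # Illustrative example (assuming a 30 s `max_duration_in_seconds`): three consecutive 12 s
    # utterances from one speaker become a 24 s sample plus a 12 s sample; the second sample gets
    # condition_on_prev=1 because it continues the same speaker and was split only by the length cap.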

    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
    if data_args.concatenate_audio and not data_args.streaming:
        with accelerator.main_process_first():
            raw_datasets = raw_datasets.map(
                concatenate_dataset,
                batched=True,
                batch_size=preprocessing_batch_size,
                num_proc=num_workers,
                remove_columns=set(raw_datasets_features)
                - {audio_column_name, text_column_name, id_column_name, "condition_on_prev"},
                desc="Concatenating dataset...",
            )

        raw_datasets = raw_datasets.cast_column(
            audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
        )
        pretty_name = data_args.dataset_name.split("/")[-1]

        def postprocess_ids(speaker_ids, indices):
            speaker_ids_formatted = []
            for speaker, idx in zip(speaker_ids, indices):
                formatted_idx = f"{pretty_name}-{speaker}-{idx}" if speaker is not None else f"{pretty_name}-{idx}"
                speaker_ids_formatted.append(formatted_idx)
            return {id_column_name: speaker_ids_formatted}
        
        with accelerator.main_process_first():
            raw_datasets = raw_datasets.map(
                postprocess_ids,
                input_columns=[id_column_name],
                with_indices=True,
                desc="Setting sample idxs...",
                batched=True,
                batch_size=preprocessing_batch_size,
                num_proc=num_workers,
            )
    elif data_args.concatenate_audio and data_args.streaming:
        raise ValueError(
            "Streaming mode is not yet compatible with concatenating audios to `max_duration_in_seconds`."
            "Either set `--streaming=False` and download the audios locally, or open an issue on the Distil-Whisper repo to request this feature."
        )

    def prepare_dataset(batch):
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
        # process audio length
        batch[model_input_name] = inputs.get(model_input_name)[0]

        # process targets
        input_str = batch[text_column_name]
        batch["labels"] = tokenizer(input_str, max_length=max_label_length, truncation=True).input_ids
        return batch
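    # For Whisper, `model_input_name` resolves to "input_features" (the log-Mel spectrogram);
    # labels are the tokenized reference text, truncated to `max_label_length`.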

    raw_datasets_features = list(next(iter(raw_datasets.values())).features.keys())
    file_ids_dataset = IterableDatasetDict() if data_args.streaming else DatasetDict()
    for split in raw_datasets:
        file_ids_dataset[split] = raw_datasets[split][id_column_name]
    if data_args.streaming:
        with accelerator.main_process_first():
            vectorized_datasets = raw_datasets.map(prepare_dataset, remove_columns=raw_datasets_features)
    else:
        with accelerator.main_process_first():
            vectorized_datasets = raw_datasets.map(
                prepare_dataset,
                remove_columns=raw_datasets_features,
                num_proc=num_workers,
                desc="preprocess dataset",
            )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")
        return

    if data_args.streaming and dataloader_num_workers > 0:
        logger.warning(
            "Using multiple dataloader num workers with streaming mode will result in different shards of "
            "data being transcribed in parallel. This is not advised if you want to preserve the order of the "
            "audio-text data."
        )

    # Handle the repository creation
    output_dir = training_args.output_dir
    if accelerator.is_main_process:
        if training_args.push_to_hub:
            if training_args.hub_model_id is None:
                repo_name = get_full_repo_name(
                    Path(output_dir).absolute().name,
                    token=training_args.hub_token,
                )
            else:
                repo_name = training_args.hub_model_id
            create_repo(repo_name, repo_type="dataset", exist_ok=True, token=training_args.hub_token)
            snapshot_download(repo_id=repo_name, repo_type="dataset", local_dir=output_dir, token=training_args.hub_token)

            # Ensure large txt files can be pushed to the Hub with git-lfs
            with open(os.path.join(output_dir, ".gitattributes"), "r+") as f:
                git_lfs_extensions = f.read()
                if "*.csv" not in git_lfs_extensions:
                    f.write("*.csv filter=lfs diff=lfs merge=lfs -text")

        elif output_dir is not None:
            # this is where we'll save our transcriptions
            os.makedirs(output_dir, exist_ok=True)

    accelerator.wait_for_everyone()

    # 8. Load Metric
    metric = evaluate.load("wer")

    def compute_metrics(preds, labels, file_ids):
        # replace padded labels by the padding token
        for idx in range(len(labels)):
            labels[idx][labels[idx] == -100] = tokenizer.pad_token_id

        pred_str = tokenizer.batch_decode(preds, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # normalize everything and re-compute the WER
        norm_pred_str = [normalizer(pred) for pred in pred_str]
        norm_label_str = [normalizer(label) for label in label_str]
        # for logging, we need the pred/labels to match the norm_pred/norm_labels, so discard any filtered samples here
        pred_str = [pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
        label_str = [label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
        file_ids = [file_ids[i] for i in range(len(file_ids)) if len(norm_label_str[i]) > 0]
        # filtering step to only evaluate the samples that correspond to non-zero normalized references:
        norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
        norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]

        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)

        return {"wer": wer}, pred_str, label_str, norm_pred_str, norm_label_str, file_ids

    def filter_eot_tokens(preds):
        for idx in range(len(preds)):
            # remove the EOT tokens to get the 'true' token length
            token_ids = [token for token in preds[idx] if token != decoder_eot_token_id]
            token_ids = token_ids + [decoder_eot_token_id]
            preds[idx] = token_ids
        return preds
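    # e.g. [sot, t1, t2, eot, eot] -> [sot, t1, t2, eot]: padding/duplicate EOT tokens are stripped
    # and a single terminator is re-appended.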

    # 9. Define the evaluation batch size and data collator
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)

    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,  # <|startoftranscript|>
        input_padding="longest",
        target_padding="max_length",
        max_target_length=max_label_length,
    )
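    # Input features are padded to the longest example in the batch; labels are padded to
    # `max_target_length`, keeping label tensors a fixed shape when gathered across processes.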

    # 10. Define generation arguments - we need to do this before we wrap the model in DDP
    # so that we can still access the configs
    num_beams = (
        training_args.generation_num_beams
        if training_args.generation_num_beams is not None
        else getattr(model.generation_config, "num_beams", 1)
    )

    gen_kwargs = {
        "max_length": max_label_length,
        "num_beams": num_beams,
        "return_timestamps": return_timestamps,
    }
    if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
        # forcing the language and task tokens helps multilingual models in their generations
        gen_kwargs.update(
            {
                "language": data_args.language,
                "task": data_args.task,
            }
        )
    # remove any preset forced decoder ids since these are deprecated
    model.generation_config.forced_decoder_ids = None
    model.config.forced_decoder_ids = None
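    # Clearing these is safe because the `language`/`task` arguments passed to `generate`
    # supersede `forced_decoder_ids`.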

    # 11. Prepare everything with accelerate
    model = accelerator.prepare(model)

    def eval_step_with_save(split="eval"):
        # ======================== Evaluating ==============================
        eval_preds = []
        eval_labels = []
        eval_ids = []
        pred_str = []
        eval_start = time.time()

        eval_loader = DataLoader(
            vectorized_datasets[split],
            batch_size=per_device_eval_batch_size,
            collate_fn=data_collator,
            num_workers=dataloader_num_workers,
            pin_memory=True,
        )
        file_loader = DataLoader(
            file_ids_dataset[split],
            batch_size=per_device_eval_batch_size * accelerator.num_processes,
            num_workers=dataloader_num_workers,
        )
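        # `eval_loader` is sharded across processes by `accelerator.prepare` below, whereas
        # `file_loader` is not, so its batch size is scaled by `num_processes` to stay aligned
        # with the gathered predictions.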

        eval_loader = accelerator.prepare(eval_loader)
        batches = tqdm(eval_loader, desc=f"Evaluating {split}...", disable=not accelerator.is_local_main_process)

        # make the split name pretty for librispeech etc
        split = split.replace(".", "-").split("/")[-1]
        output_csv = os.path.join(output_dir, f"{split}-transcription.csv")

        for step, (batch, file_ids) in enumerate(zip(batches, file_loader)):
            # Generate predictions and pad to max generated length
            generate_fn = model.module.generate if accelerator.num_processes > 1 else model.generate
            generated_ids = generate_fn(batch["input_features"].to(dtype=torch_dtype), **gen_kwargs)
            generated_ids = accelerator.pad_across_processes(generated_ids, dim=1, pad_index=tokenizer.pad_token_id)
            # Gather all predictions and targets
            generated_ids, labels = accelerator.gather_for_metrics((generated_ids, batch["labels"]))
            eval_preds.extend(generated_ids.cpu().numpy())
            eval_labels.extend(labels.cpu().numpy())
            eval_ids.extend(file_ids)

            if step % training_args.logging_steps == 0 and step > 0:
                batches.write(f"Saving transcriptions for split {split} step {step}")
                accelerator.wait_for_everyone()
                pred_ids = eval_preds[len(pred_str) :]
                pred_ids = filter_eot_tokens(pred_ids)
                pred_str.extend(
                    tokenizer.batch_decode(
                        pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps
                    )
                )
                csv_data = [[eval_ids[i], pred_str[i]] for i in range(len(eval_preds))]

                with open(output_csv, "w", encoding="UTF8", newline="") as f:
                    writer = csv.writer(f)
                    # write multiple rows
                    writer.writerow(["file_id", "whisper_transcript"])
                    writer.writerows(csv_data)

                if training_args.push_to_hub and accelerator.is_main_process:
                    upload_folder(
                        folder_path=output_dir,
                        repo_id=repo_name,
                        repo_type="dataset",
                        token=training_args.hub_token,
                        commit_message=f"Saving transcriptions for split {split} step {step}.",
                    )

        accelerator.wait_for_everyone()
        eval_time = time.time() - eval_start

        # compute WER metric for eval sets
        wer_desc = ""
        if "validation" in split or "test" in split:
            eval_preds = filter_eot_tokens(eval_preds)
            wer_metric, pred_str, label_str, norm_pred_str, norm_label_str, eval_ids = compute_metrics(
                eval_preds, eval_labels, eval_ids
            )
            wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
            # Save metrics + predictions
            log_metric(
                accelerator,
                metrics=wer_metric,
                train_time=eval_time,
                prefix=split,
            )
            log_pred(
                accelerator,
                pred_str,
                label_str,
                norm_pred_str,
                norm_label_str,
                prefix=split,
            )
        else:
            pred_ids = eval_preds[len(pred_str) :]
            pred_ids = filter_eot_tokens(pred_ids)
            pred_str.extend(
                tokenizer.batch_decode(pred_ids, skip_special_tokens=False, decode_with_timestamps=return_timestamps)
            )

        batches.write(f"Saving final transcriptions for split {split}.")
        csv_data = [[eval_ids[i], pred_str[i]] for i in range(len(pred_str))]
        with open(output_csv, "w", encoding="UTF8", newline="") as f:
            writer = csv.writer(f)
            # write multiple rows
            writer.writerow(["file_id", "whisper_transcript"])
            writer.writerows(csv_data)

        # Print metrics
        logger.info(wer_desc)

        if not data_args.streaming:
            raw_datasets[split] = raw_datasets[split].add_column("whisper_transcript", pred_str)
            raw_datasets[split] = raw_datasets[split].add_column("eval_preds", eval_preds)

            def add_concatenated_text(eval_preds, condition_on_prev):
                concatenated_prev = [None]
                for token_ids, condition in zip(eval_preds[:-1], condition_on_prev[1:]):
                    if not condition:
                        concatenated_prev.append(None)
                    else:
                        prompt_ids = [token for token in token_ids if token != decoder_eot_token_id]
                        prompt_ids = [decoder_prev_token_id] + prompt_ids[timestamp_position:]
                        concatenated_prev.append(prompt_ids)
                return {"condition_on_prev": concatenated_prev}

            if data_args.concatenate_audio:
                with accelerator.main_process_first():
                    raw_datasets[split] = raw_datasets[split].map(
                        add_concatenated_text,
                        input_columns=["eval_preds", "condition_on_prev"],
                        remove_columns=["eval_preds"],
                        desc="Setting condition on prev...",
                        batched=True,
                        batch_size=preprocessing_batch_size,
                        num_proc=num_workers,
                    )

    logger.info("***** Running Labelling *****")
    logger.info("  Instantaneous batch size per device =" f" {training_args.per_device_eval_batch_size}")
    logger.info(
        f"  Total eval batch size (w. parallel & distributed) = {training_args.per_device_eval_batch_size * accelerator.num_processes}"
    )
    logger.info(f"  Predict labels with timestamps = {return_timestamps}")
    for split in data_splits:
        eval_step_with_save(split=split)
        accelerator.wait_for_everyone()
        if training_args.push_to_hub and accelerator.is_main_process:
            upload_folder(
                folder_path=output_dir,
                repo_id=repo_name,
                repo_type="dataset",
                token=training_args.hub_token,
                commit_message=f"Saving final transcriptions for split {split.replace('.', '-').split('/')[-1]}",
            )
    if not data_args.streaming and accelerator.is_main_process:
        raw_datasets.save_to_disk(output_dir, num_proc=num_workers)
        if training_args.push_to_hub:
            raw_datasets.push_to_hub(repo_name, token=training_args.hub_token, config_name=data_args.dataset_config_name)
    accelerator.end_training()