def main()

in training/flax/run_distillation.py [0:0]


def main():
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your JAX/Flax versions.
    send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")

    # 2. Define remote logging - do this early so that we get the full traceback on our remote logs
    # Enable tensorboard only on the master node
    has_tensorboard = is_tensorboard_available()
    if has_tensorboard:
        if jax.process_index() == 0:
            try:
                from flax.metrics.tensorboard import SummaryWriter

                summary_writer = SummaryWriter(log_dir=os.path.join(Path(training_args.output_dir), "runs"))
            except ImportError as ie:
                has_tensorboard = False
                logger.warning(
                    "Unable to display metrics through TensorBoard because some package" f" are not installed: {ie}"
                )
    else:
        logger.warning(
            "Unable to display metrics through TensorBoard because the package is not"
            " installed: Please run `pip install tensorboard` to enable."
        )

    # Enable wandb only on the master node
    has_wandb = is_wandb_available()
    if has_wandb:
        import wandb as wandb_logger

        # Set up wandb run
        if jax.process_index() == 0:
            wandb_logger.init(
                project=data_args.wandb_project,
                name=data_args.wandb_name,
                job_type=data_args.wandb_job_type,
                dir=data_args.wandb_dir,
                save_code=data_args.save_code_to_wandb,
            )
    else:
        logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")

    # 3. Setup local logging
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    # Set the verbosity to info of the Transformers logger.
    # We only want one process per machine to log things on the screen.
    logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
    if jax.process_index() == 0:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

    logger.info("Training/evaluation parameters %s", training_args)

    # Check the output dir is valid
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not"
            " empty. Use `--overwrite_output_dir` to overcome."
        )

    # 4. Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token,
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
        repo = Repository(
            training_args.output_dir,
            clone_from=repo_name,
            token=training_args.hub_token,
        )

    if training_args.compilation_cache:
        cc.initialize_cache(os.path.join(model_args.cache_dir, "jax_cache"))

    # 5. Load dataset
    raw_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()

    # set seed for determinism
    set_seed(training_args.seed)

    if training_args.do_train:
        raw_datasets["train"] = load_multiple_datasets(
            data_args.train_dataset_name,
            data_args.train_dataset_config_name,
            splits=data_args.train_split_name,
            streaming=data_args.streaming,
            dataset_samples=data_args.train_dataset_samples,
            seed=training_args.seed,
            cache_dir=data_args.dataset_cache_dir,
            token=True if model_args.use_auth_token else None,
        )

    if training_args.do_eval:
        dataset_names_dict = convert_dataset_str_to_list(
            data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
            (
                data_args.eval_dataset_config_name
                if data_args.eval_dataset_config_name
                else data_args.train_dataset_config_name
            ),
            splits=data_args.eval_split_name,
            text_column_names=data_args.eval_text_column_name,
        )
        all_eval_splits = []
        if len(dataset_names_dict) == 1:
            # load a single eval set
            dataset_dict = dataset_names_dict[0]
            all_eval_splits.append("eval")
            raw_datasets["eval"] = load_dataset(
                dataset_dict["name"],
                dataset_dict["config"],
                split=dataset_dict["split"],
                cache_dir=data_args.dataset_cache_dir,
                token=True if model_args.use_auth_token else None,
                streaming=data_args.streaming,
            )
        else:
            # load multiple eval sets
            for dataset_dict in dataset_names_dict:
                if dataset_dict["name"] == "esb/diagnostic-dataset":
                    # for the ESB diagnostic dataset, the dataset name is effectively the config
                    pretty_name = f"{dataset_dict['config']}-diagnostic/{dataset_dict['split']}"
                else:
                    pretty_name = f"{dataset_dict['name'].split('/')[-1]}/{dataset_dict['split'].replace('.', '-')}"
                all_eval_splits.append(pretty_name)
                raw_datasets[pretty_name] = load_dataset(
                    dataset_dict["name"],
                    dataset_dict["config"],
                    split=dataset_dict["split"],
                    cache_dir=data_args.dataset_cache_dir,
                    token=True if model_args.use_auth_token else None,
                    streaming=data_args.streaming,
                )
                features = raw_datasets[pretty_name].features.keys()
                if "text" not in features:
                    raw_datasets[pretty_name] = raw_datasets[pretty_name].rename_column(
                        dataset_dict["text_column_name"], "text"
                    )
                raw_datasets[pretty_name] = raw_datasets[pretty_name].remove_columns(
                    set(raw_datasets[pretty_name].features.keys()) - {"audio", "text"}
                )

    if not training_args.do_train and not training_args.do_eval:
        raise ValueError(
            "Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
        )

    raw_datasets_train_features = list(raw_datasets["train"].features.keys())

    if data_args.audio_column_name not in raw_datasets_train_features:
        raise ValueError(
            f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
            " the correct audio column - one of"
            f" {', '.join(raw_datasets_train_features)}."
        )

    if data_args.train_text_column_name not in raw_datasets_train_features:
        raise ValueError(
            f"--train_text_column_name {data_args.train_text_column_name} not found in dataset"
            f" '{data_args.dataset_name}'. Make sure to set `--train_text_column_name` to the"
            " correct text column - one of"
            f" {', '.join(raw_datasets_train_features)}."
        )

    # 6. Load pretrained model, tokenizer, and feature extractor
    config = WhisperConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=True if model_args.use_auth_token else None,
    )
    feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(
        (model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=True if model_args.use_auth_token else None,
    )
    tokenizer = WhisperTokenizerFast.from_pretrained(
        (model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=True if model_args.use_auth_token else None,
    )

    # override timestamp tokens until tokenizer issues are fixed in transformers
    timestamps = [AddedToken("<|%.2f|>" % (i * 0.02), lstrip=False, rstrip=False) for i in range(1500 + 1)]
    tokenizer.add_tokens(timestamps)

    config.update(
        {
            "activation_dropout": model_args.activation_dropout,
            "attention_dropout": model_args.attention_dropout,
            "dropout": model_args.dropout,
        }
    )

    if training_args.precision == "full_mixed":
        # forward pass, backward pass and optimiser states in bf16
        dtype = jnp.bfloat16
        to_dtype = to_bf16
    elif training_args.precision == "half_mixed" or model_args.dtype == "bfloat16":
        # forward pass in bf16, backward pass and optimiser states in fp32
        dtype = jnp.bfloat16
        to_dtype = to_fp32
    else:
        if training_args.precision != "full":
            raise ValueError(
                f"`precision` should be one of: `full`, `half_mixed` or `full_mixed`, got {training_args.precision}"
            )
        # forward pass, backward pass and optimiser states in fp32
        dtype = jnp.float32
        to_dtype = to_fp32

    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=dtype,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        subfolder=model_args.subfolder,
        token=True if model_args.use_auth_token else None,
        _do_init=False,
        use_scan=model_args.load_with_scan_weights,
    )

    teacher_model, teacher_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.teacher_model_name_or_path,
        # config=config,
        dtype=dtype,
        cache_dir=model_args.cache_dir,
        # revision=model_args.model_revision,
        token=True if model_args.use_auth_token else None,
        _do_init=False,
    )

    if student_model.config.decoder_start_token_id is None or teacher_model.config.decoder_start_token_id is None:
        raise ValueError(
            f"Make sure that `config.decoder_start_token_id` is correctly defined for both the "
            f"student and teacher model. Got {student_model.config.decoder_start_token_id} for the "
            f"student and {teacher_model.config.decoder_start_token_id} for the teacher."
        )

    # enable scan / gradient checkpointing if necessary
    if training_args.use_scan:
        student_model.enable_scan()  # to enable scan in the nn.Module
        student_params = student_model.convert_unroll_to_scan(student_params)  # to convert the unrolled params to scan

        teacher_model.enable_scan()  # faster compile time (even though we don't train the teacher)
        teacher_params = teacher_model.convert_unroll_to_scan(teacher_params)

    if training_args.gradient_checkpointing:
        student_model.enable_gradient_checkpointing()  # to enable checkpointing in the nn.Module, there is no change to the params structure
        teacher_model.enable_gradient_checkpointing()

    if hasattr(teacher_model.generation_config, "is_multilingual") and teacher_model.generation_config.is_multilingual:
        # We need to set the language and task ids for previously multilingual checkpoints - for now we hardcode this to English
        tokenizer.set_prefix_tokens(language="English", task="transcribe", predict_timestamps=False)
        student_model.generation_config.update(
            **{
                "language": "<|en|>",
                "task": "transcribe",
            }
        )

    # 7. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
    # so we just need to set the correct target sampling rate.
    raw_datasets = raw_datasets.cast_column(
        data_args.audio_column_name,
        datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
    )

    # 8. Preprocessing the datasets.
    # We need to read the audio files as arrays and tokenize the targets.
    max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
    min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
    max_label_length = (
        data_args.max_label_length if data_args.max_label_length is not None else student_model.config.max_length
    )
    audio_column_name = data_args.audio_column_name
    num_workers = data_args.preprocessing_num_workers
    dataloader_num_workers = training_args.dataloader_num_workers
    dataloader_prefetch_size = data_args.prefetch_size
    train_text_column_name = data_args.train_text_column_name
    eval_text_column_name = "text"
    model_input_name = feature_extractor.model_input_names[0]
    normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
    wer_threshold = data_args.wer_threshold
    round_timestamps = data_args.round_timestamps

    if training_args.do_train and data_args.max_train_samples is not None:
        raw_datasets["train"] = (
            raw_datasets["train"].take(data_args.max_train_samples)
            if data_args.streaming
            else raw_datasets["train"].select(range(data_args.max_train_samples))
        )

    if training_args.do_eval and data_args.max_eval_samples is not None:
        for eval_split in all_eval_splits:
            raw_datasets[eval_split] = (
                raw_datasets[eval_split].take(data_args.max_eval_samples)
                if data_args.streaming
                else raw_datasets[eval_split].select(range(data_args.max_eval_samples))
            )

    def is_wer_in_range(ground_truth, whisper_transcript):
        """Decide whether a pseudo-labelled sample passes the WER threshold.

        Both texts are normalised with ``normalizer``. The WER (in percent) of the Whisper
        transcript against the ground truth must be strictly below ``wer_threshold`` for the
        sample to be kept.

        Samples whose normalised ground truth is empty, or whose transcript is ``None``, return
        False: their WER cannot be computed, so the filter drops them.

        NOTE(review): ``metric`` is a closure variable that ``main`` only assigns after the WER
        filter has already been run. Calling this before that assignment raises NameError, so
        confirm the metric is loaded before filtering.
        """
        reference = normalizer(ground_truth)
        # guard clause: nothing to score against, so discard the sample
        if not (len(reference) > 0 and whisper_transcript is not None):
            return False
        hypothesis = normalizer(whisper_transcript)
        error_rate = 100 * metric.compute(predictions=[hypothesis], references=[reference])
        return error_rate < wer_threshold

    filter_by_wer_threshold = partial(
        raw_datasets["train"].filter,
        function=is_wer_in_range,
        input_columns=[eval_text_column_name, train_text_column_name],
    )

    if wer_threshold is not None:
        raw_datasets["train"] = (
            filter_by_wer_threshold(num_proc=num_workers, desc="filtering train dataset by wer")
            if not data_args.streaming
            else filter_by_wer_threshold()
        )

    def has_timestamp_tokens(input_str):
        """
        Identify whether the input string contains timestamp tokens, of the form <|0.00|>, by searching for
        pairs of left and right-angle brackets.

        Any ``<...>`` span counts, not only timestamps. The pattern is written as a raw string so no
        invalid string escape sequences are needed (those emit a SyntaxWarning on Python 3.12+).
        """
        return re.search(r"<[^>]*>", input_str) is not None

    def round_timestamp_tokens(input_str: str, ndigits: int = 1):
        """Round every timestamp token in ``input_str`` to ``ndigits`` decimal places.

        Rounded tokens keep two-decimal formatting. For example, with ``ndigits=1`` the token
        ``<|6.24|>`` becomes ``<|6.20|>``.

        Only well-formed ``<|number|>`` tokens are rewritten. Other angle-bracket content, such as
        ``<unk>`` or ``<|en|>``, is left unchanged rather than crashing ``float()`` with ValueError.

        Args:
            input_str: text that may contain timestamp tokens.
            ndigits: number of decimal places to round each timestamp to.

        Returns:
            ``input_str`` with every timestamp token rounded.
        """

        def _round_match(match):
            # e.g. "6.24" -> 6.2 -> "<|6.20|>"
            rounded = round(float(match.group(1)), ndigits=ndigits)
            return "<|{:.2f}|>".format(rounded)

        # single pass over the string; each match is replaced exactly once
        return re.sub(r"<\|(\d+(?:\.\d+)?)\|>", _round_match, input_str)

    def prepare_train_dataset(batch):
        """Featurise the audio and tokenise the target text of one training example.

        Adds three fields to ``batch`` and returns it:

        - ``model_input_name``: output of ``feature_extractor``.
        - ``input_length``: number of raw audio samples.
        - ``labels``: token ids.

        Targets that already start with ``<|startoftranscript|>`` or ``<|startofprev|>`` are tokenised
        without extra special tokens; no timestamp handling is applied to them. For other targets
        that contain angle-bracket tokens, timestamp prediction is kept with probability
        ``data_args.timestamp_probability``:

        - If not kept, the timestamps are filtered out.
        - If kept and ``round_timestamps`` is set, the timestamps are rounded.

        Side effect: resets the tokenizer's prefix tokens on every non-prompted call.
        """
        audio = batch[audio_column_name]
        features = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
        batch[model_input_name] = features.get(model_input_name)[0]
        batch["input_length"] = len(audio["array"])

        text = batch[train_text_column_name]

        # prompt & timestamp processing are mutually exclusive for now
        if text.startswith(("<|startoftranscript|>", "<|startofprev|>")):
            # prompted targets already carry their special ids
            batch["labels"] = tokenizer(text, add_special_tokens=False).input_ids
            return batch

        predict_timestamps = False
        if has_timestamp_tokens(text):
            predict_timestamps = bool(np.random.binomial(1, data_args.timestamp_probability))
            if not predict_timestamps:
                # drop timestamp tokens when they are not part of the prediction task
                text = tokenizer._filter_timestamp_ids(text)
            elif round_timestamps:
                text = round_timestamp_tokens(text)

        tokenizer.set_prefix_tokens(language="English", task="transcribe", predict_timestamps=predict_timestamps)
        batch["labels"] = tokenizer(text).input_ids
        return batch

    def prepare_eval_dataset(batch):
        """Featurise the audio and tokenise the reference text of one evaluation example.

        Writes ``model_input_name`` (feature-extractor output), ``input_length`` (raw audio sample
        count) and ``labels`` (token ids of the ``eval_text_column_name`` column) into ``batch``.
        """
        audio = batch[audio_column_name]
        extracted = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"])
        batch[model_input_name] = extracted.get(model_input_name)[0]
        batch["input_length"] = len(audio["array"])

        reference_text = batch[eval_text_column_name]
        batch["labels"] = tokenizer(reference_text).input_ids
        return batch

    vectorized_datasets = IterableDatasetDict() if data_args.streaming else DatasetDict()
    if training_args.do_train:
        map_fn_train = partial(
            raw_datasets["train"].map, function=prepare_train_dataset, remove_columns=raw_datasets_train_features
        )
        vectorized_datasets["train"] = (
            map_fn_train(num_proc=num_workers, desc="preprocess train dataset")
            if not data_args.streaming
            else map_fn_train()
        )
    if training_args.do_eval:
        for eval_split in all_eval_splits:
            raw_datasets_eval_features = list(raw_datasets[eval_split].features.keys())
            map_fn_eval = partial(
                raw_datasets[eval_split].map, function=prepare_eval_dataset, remove_columns=raw_datasets_eval_features
            )
            vectorized_datasets[eval_split] = (
                map_fn_eval(num_proc=num_workers, desc="preprocess eval dataset")
                if not data_args.streaming
                else map_fn_eval()
            )

    # filter training data with inputs longer than max_input_length
    def is_audio_in_length_range(length):
        """Keep samples whose raw audio length is strictly between ``min_input_length`` and ``max_input_length``."""
        long_enough = length > min_input_length
        return long_enough and length < max_input_length

    filter_by_audio_fn = partial(
        vectorized_datasets.filter, function=is_audio_in_length_range, input_columns=["input_length"]
    )
    vectorized_datasets = (
        filter_by_audio_fn(num_proc=num_workers, desc="filtering train dataset by audio length")
        if not data_args.streaming
        else filter_by_audio_fn()
    )

    # filter training data with labels longer than max_label_length
    def is_labels_in_length_range(labels):
        """Keep samples with a non-empty label sequence shorter than ``max_label_length`` tokens."""
        num_tokens = len(labels)
        return num_tokens > 0 and num_tokens < max_label_length

    filter_by_labels_fn = partial(
        vectorized_datasets.filter, function=is_labels_in_length_range, input_columns=["labels"]
    )
    vectorized_datasets = (
        filter_by_labels_fn(num_proc=num_workers, desc="filtering train dataset")
        if not data_args.streaming
        else filter_by_labels_fn()
    )

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with `args.preprocessing_only` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step `args.preprocessing_only` can then be set to `False` to load the
    # cached dataset
    if data_args.preprocessing_only:
        cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
        logger.info(f"Data preprocessing finished. Files cached at {cache}.")
        return

    # 8. Load Metric
    metric = evaluate.load("wer")
    # convention is that we space all punctuation *except* apostrophes
    all_punctuation = list(string.punctuation.replace("'", ""))
    return_timestamps = data_args.return_timestamps if data_args.timestamp_probability > 0 else False

    def compute_metrics(preds, labels):
        """Decode predictions and references, then compute orthographic and normalised WER.

        Args:
            preds: generated token ids, one sequence per sample.
            labels: reference token ids, one sequence per sample. Entries equal to -100 mark padding
                and are mapped to ``tokenizer.pad_token_id`` before decoding. The caller's ``labels``
                are not modified. NOTE(review): assumes each sequence is a numpy array; confirm
                against the caller.

        Returns:
            ``(metrics, pred_str, label_str, norm_pred_str, norm_label_str)``.

            - ``metrics`` is ``{"wer": ..., "wer_ortho": ...}``, both in percent.
            - ``wer_ortho`` is computed on all samples with punctuation spaced out.
            - ``wer`` is computed on the normalised strings of the samples whose normalised reference
              is non-empty. The four returned string lists contain only those samples.
        """
        # map padding to the pad token on a copy so the caller's arrays are left untouched
        labels = [np.where(label == -100, tokenizer.pad_token_id, label) for label in labels]

        pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True, decode_with_timestamps=return_timestamps)
        # we do not want to group tokens when computing the metrics
        label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Space punctuation for orthographic WER (c.f. ESB paper https://arxiv.org/abs/2210.13352).
        # Every mark in `all_punctuation` (apostrophes excluded) is padded with spaces in each string,
        # yielding exactly one spaced string per input string.
        spacing_table = str.maketrans({punctuation: f" {punctuation} " for punctuation in all_punctuation})
        spaced_pred_str = [pred.translate(spacing_table) for pred in pred_str]
        spaced_label_str = [label.translate(spacing_table) for label in label_str]
        wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)

        # normalise everything and re-compute the WER
        norm_pred_str = [normalizer(pred) for pred in pred_str]
        norm_label_str = [normalizer(label) for label in label_str]

        # Only evaluate samples with a non-empty normalised reference. The raw strings are filtered
        # identically so that logged predictions/labels line up with the normalised ones.
        keep = [i for i, norm_label in enumerate(norm_label_str) if len(norm_label) > 0]
        pred_str = [pred_str[i] for i in keep]
        label_str = [label_str[i] for i in keep]
        norm_pred_str = [norm_pred_str[i] for i in keep]
        norm_label_str = [norm_label_str[i] for i in keep]

        wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)

        return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str, norm_pred_str, norm_label_str

    # 9. Save feature extractor, tokenizer, config and generation config
    feature_extractor.save_pretrained(training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)
    config.save_pretrained(training_args.output_dir)
    student_model.generation_config.save_pretrained(
        training_args.output_dir
    )  # generation config stays bound to model to make it easy to jit

    processor = WhisperProcessor.from_pretrained(training_args.output_dir)

    data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=student_model.config.decoder_start_token_id,  # <|startoftranscript|>
        decoder_prev_token_id=tokenizer.all_special_ids[-3],  # <|startofprev|>
        input_padding="longest",
        target_padding="max_length",
        max_target_length=max_label_length,
    )

    # Initialize our training
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    # Store some constants
    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
    gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
    eval_batch_size = per_device_eval_batch_size * jax.device_count()

    if not data_args.streaming and training_args.max_steps < 0:
        num_epochs = int(training_args.num_train_epochs)
        steps_per_epoch = len(vectorized_datasets["train"]) // train_batch_size
        total_train_steps = steps_per_epoch * num_epochs
    elif training_args.max_steps > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")
        total_train_steps = int(training_args.max_steps)
        # Setting a very large number of epochs so we go as many times as necessary over the iterator.
        num_epochs = sys.maxsize
        steps_per_epoch = total_train_steps
    else:
        raise ValueError("max_steps must be specified when training with a streaming (iterable) dataset")

    if training_args.eval_steps is None:
        logger.info(
            f"eval_steps is not set, evaluating at the end of {'each epoch' if not data_args.streaming else 'training'}"
        )
        eval_steps = steps_per_epoch
    else:
        eval_steps = training_args.eval_steps

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        total_train_steps * gradient_accumulation_steps,
        training_args.lr_scheduler_type,
        training_args.warmup_steps * gradient_accumulation_steps,
        training_args.learning_rate,
    )

    # We use Optax's "masking" functionality to not apply weight decay
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a
    # mask boolean with the same structure as the parameters.
    # The mask is True for parameters that should be decayed.
    def decay_mask_fn(params):
        """Build an optax weight-decay mask with the same tree structure as ``params``.

        A leaf is True (weight decay applied) unless one of the following holds:

        - Its name is ``bias``.
        - Its last two path components equal those of some parameter whose lower-cased, joined
          path contains ``layer_norm``.
        """
        flat_params = traverse_util.flatten_dict(params)

        # Every LayerNorm name used by the model ("layer_norm", "self_attn_layer_norm",
        # "final_layer_norm", "encoder_attn_layer_norm") contains "layer_norm", so one substring
        # test selects the same parameters as checking each name separately.
        layer_norm_suffixes = set()
        for path in flat_params:
            if "layer_norm" in "".join(path).lower():
                layer_norm_suffixes.add(path[-2:])

        flat_mask = {}
        for path in flat_params:
            is_bias = path[-1] == "bias"
            is_layer_norm = path[-2:] in layer_norm_suffixes
            flat_mask[path] = not is_bias and not is_layer_norm
        return traverse_util.unflatten_dict(flat_mask)

    # create adam optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )

    if gradient_accumulation_steps > 1:
        # accumulate gradients and apply once every k steps
        adamw = optax.MultiSteps(adamw, every_k_schedule=gradient_accumulation_steps)

    share_hidden_states = training_args.freeze_encoder and student_model.config.d_model == teacher_model.config.d_model
    encoder_layer_mapping = get_layers_to_supervise(
        student_model.config.encoder_layers, teacher_model.config.encoder_layers
    )
    decoder_layer_mapping = get_layers_to_supervise(
        student_model.config.decoder_layers, teacher_model.config.decoder_layers
    )

    # Setup train state
    student_state = TrainState.create(
        apply_fn=student_model.decode if share_hidden_states else student_model.__call__,
        params=student_params,
        tx=adamw,
        to_dtype=to_dtype,
        dropout_rng=dropout_rng,
        max_grad_norm=training_args.max_grad_norm,
    )

    if training_args.resume_from_checkpoint is not None:
        if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
            logger.info(
                f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
                "this behavior, omit the resume_from_checkpoint argument."
            )
            with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
                student_state = from_bytes(student_state, f.read())
        else:
            logger.warning(
                f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
                f"you pass the path to a folder with a valid checkpoint for your model."
            )

    def cross_entropy_loss(logits, labels):
        """Summed token-level cross-entropy over non-padded positions.

        Returns ``(loss_sum, num_labels)``. Positions with negative labels (the -100 padding
        convention) contribute neither loss nor count.
        """
        # The one-hot targets are cast with `to_dtype` so they match the training precision
        # (per the original note, onehot returns float32).
        targets = to_dtype(onehot(labels, logits.shape[-1]))
        per_token_loss = optax.softmax_cross_entropy(logits, targets)
        is_real_token = labels >= 0
        return (per_token_loss * is_real_token).sum(), is_real_token.sum()

    # KL-divergence between target and predicted distributions (any temperature smoothing is
    # presumably applied by the caller before the distributions are passed in; none happens here)
    def kl_divergence(target_distribution, log_predicted_distribution, labels, eps=1e-20):
        """Summed KL(target || predicted) over non-padded positions.

        ``eps`` keeps the log finite where the target probability is zero. Positions whose label is
        negative (-100 padding) are masked out. The result is cast with ``to_dtype`` so it matches
        the backprop dtype.
        """
        log_target = jnp.log(target_distribution + eps)
        pointwise = -target_distribution * (log_predicted_distribution - log_target)
        # add a trailing axis so the per-position mask broadcasts over the last (presumably vocab) axis
        keep = jnp.expand_dims(labels >= 0, axis=-1)
        return to_dtype((pointwise * keep).sum())

    def mean_square_error_loss(student_outputs, teacher_outputs):
        """Sum of per-layer MSEs between teacher and student hidden states.

        For the encoder and then the decoder, this adds two kinds of term:

        - The embedding hidden state (index 0).
        - Every (student, teacher) layer pair from ``encoder_layer_mapping`` /
          ``decoder_layer_mapping``. Layer ids are offset by one because index 0 of each
          ``*_hidden_states`` sequence is the embedding output.

        Attention maps are not supervised. The result is cast with ``to_dtype`` so it matches the
        backprop dtype.
        """

        def hidden_state_mse(student_states, teacher_states, student_idx, teacher_idx):
            diff = teacher_states[teacher_idx] - student_states[student_idx]
            return jnp.mean(jnp.square(diff))

        stacks = (
            (student_outputs.encoder_hidden_states, teacher_outputs.encoder_hidden_states, encoder_layer_mapping),
            (student_outputs.decoder_hidden_states, teacher_outputs.decoder_hidden_states, decoder_layer_mapping),
        )

        total = dtype(0.0)
        for student_states, teacher_states, layer_mapping in stacks:
            # tie the embedding outputs
            total += hidden_state_mse(student_states, teacher_states, 0, 0)
            for student_layer_id, teacher_layer_id in layer_mapping.items():
                total += hidden_state_mse(student_states, teacher_states, student_layer_id + 1, teacher_layer_id + 1)

        return to_dtype(total)

    # Define gradient update step fn
    def train_step(
        student_state,
        teacher_params,
        batch,
        freeze_encoder,
        share_hidden_states,
        temperature=2.0,
    ):
        """Run one per-device distillation step and apply the resulting gradients.

        Runs under `jax.pmap` with axis name "batch". Per-device loss sums and gradients are
        `psum`-ed across devices, then divided by the global number of non-padded label tokens.

        Args:
            student_state: replicated student train state (params, `apply_fn`, `dropout_rng`).
            teacher_params: replicated teacher parameters; never updated here.
            batch: per-device batch dict holding "labels" plus the model inputs.
            freeze_encoder: flag forwarded to the student's full forward pass.
            share_hidden_states: when True, the student decoder consumes the teacher's
                gradient-stopped encoder output instead of running its own encoder. This also
                disables hidden-state outputs, so the MSE term becomes zero.
            temperature: softmax temperature applied to both models' logits for the KL term.

        Returns:
            `(new_state, metrics)`. `metrics` holds the normalised total/CE/KL/MSE losses and
            the learning rate at the pre-update step.
        """
        dropout_rng, new_dropout_rng = jax.random.split(student_state.dropout_rng)
        # labels feed only the losses; everything left in `batch` is passed to the models
        labels = batch.pop("labels")
        want_hidden_states = not share_hidden_states and training_args.mse_weight > 0.0

        def compute_loss(student_params):
            teacher_outputs = teacher_model(
                **batch,
                params=teacher_params,
                freeze_encoder=True,
                output_hidden_states=want_hidden_states,
                train=False,
            )

            if share_hidden_states:
                # the encoder is shared and frozen: reuse the teacher's encoder output instead of
                # recomputing it in the student, and block gradients from flowing into it
                shared_encoder_state = jax.lax.stop_gradient(teacher_outputs.encoder_last_hidden_state)
                student_outputs = student_state.apply_fn(
                    decoder_input_ids=batch["decoder_input_ids"],
                    encoder_outputs=FlaxBaseModelOutput(last_hidden_state=shared_encoder_state),
                    params=student_params,
                    dropout_rng=dropout_rng,
                    train=True,
                )
            else:
                # full encoder + decoder forward pass through the student
                student_outputs = student_state.apply_fn(
                    **batch,
                    params=student_params,
                    dropout_rng=dropout_rng,
                    freeze_encoder=freeze_encoder,
                    output_hidden_states=want_hidden_states,
                    train=True,
                )

            # data term
            ce_loss, num_labels = cross_entropy_loss(student_outputs.logits, labels)

            # distillation term: temperature-softened teacher targets with no gradient into the
            # teacher, compared against the student's log-softmax (numerically stable form)
            teacher_probs = jax.lax.stop_gradient(jax.nn.softmax(teacher_outputs.logits / temperature, axis=-1))
            student_log_probs = jax.nn.log_softmax(student_outputs.logits / temperature, axis=-1)
            # scale by T^2 so gradients keep a comparable magnitude across temperatures
            kl_loss = kl_divergence(teacher_probs, student_log_probs, labels) * temperature**2

            # feature-matching term, only when hidden states were requested
            # NOTE(review): this is already a per-element mean, yet below it is psum-ed and
            # divided by the global token count like the CE/KL sums — confirm that scaling is intended
            if want_hidden_states:
                mse_loss = mean_square_error_loss(student_outputs, teacher_outputs)
            else:
                mse_loss = jnp.zeros_like(kl_loss)

            # DistilBart formulation: CE weight fixed at 0.8 whenever the KL term is active
            ce_weight = 0.8 if training_args.kl_weight > 0 else 1.0
            loss = ce_weight * ce_loss + training_args.kl_weight * kl_loss + training_args.mse_weight * mse_loss
            return loss, (ce_loss, kl_loss, mse_loss, num_labels)

        grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
        (loss, (ce_loss, kl_loss, mse_loss, num_labels)), grad = grad_fn(to_dtype(student_state.params))

        # true loss / grad = cross-device sum divided by the cross-device token count
        total_labels = jax.lax.psum(num_labels, "batch")
        loss = jax.lax.psum(loss, "batch") / total_labels
        grad = jax.tree_util.tree_map(lambda g: g / total_labels, jax.lax.psum(grad, "batch"))
        new_state = student_state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng, to_dtype=to_dtype)

        metrics = {
            "loss": loss,
            "learning_rate": linear_decay_lr_schedule_fn(student_state.step),
            "ce_loss": jax.lax.psum(ce_loss, "batch") / total_labels,
            "kl_loss": jax.lax.psum(kl_loss, "batch") / total_labels,
            "mse_loss": jax.lax.psum(mse_loss, "batch") / total_labels,
        }
        return new_state, metrics

    # Define eval fn
    def eval_step(student_params, teacher_params, batch):
        """Evaluate the distillation objective on one per-device batch.

        Mirrors the training loss with temperature fixed to 1 and both models run with
        `train=False`. Runs under `jax.pmap` (axis name "batch"): every per-device loss sum is
        `psum`-ed and divided by the global number of non-padded label tokens.

        Returns:
            dict with the normalised "loss", "ce_loss", "kl_loss" and "mse_loss".
        """
        labels = batch.pop("labels")
        want_hidden_states = not share_hidden_states and training_args.mse_weight > 0

        student_outputs = student_model(
            **batch,
            params=student_params,
            output_hidden_states=want_hidden_states,
            train=False,
        )
        ce_loss, num_labels = cross_entropy_loss(student_outputs.logits, labels)

        teacher_outputs = teacher_model(
            **batch,
            params=teacher_params,
            output_hidden_states=want_hidden_states,
            train=False,
        )
        # temperature is always 1 for eval, so the raw logits are used
        kl_loss = kl_divergence(
            jax.nn.softmax(teacher_outputs.logits, axis=-1),
            jax.nn.log_softmax(student_outputs.logits, axis=-1),
            labels,
        )

        if want_hidden_states:
            mse_loss = mean_square_error_loss(student_outputs, teacher_outputs)
        else:
            mse_loss = jnp.zeros_like(kl_loss)

        ce_weight = 0.8 if training_args.kl_weight > 0 else 1.0
        loss = ce_weight * ce_loss + training_args.kl_weight * kl_loss + training_args.mse_weight * mse_loss

        # true value = cross-device sum divided by the cross-device token count
        total_labels = jax.lax.psum(num_labels, "batch")
        per_device = {"loss": loss, "ce_loss": ce_loss, "kl_loss": kl_loss, "mse_loss": mse_loss}
        return {name: jax.lax.psum(value, "batch") / total_labels for name, value in per_device.items()}

    # Define generation function
    # an explicit generation_num_beams setting wins; otherwise use the student's config default
    num_beams = training_args.generation_num_beams
    if num_beams is None:
        num_beams = student_model.config.num_beams

    # force the English language token and the transcribe task; per the original note this
    # helps the model in its generations
    gen_kwargs = dict(
        max_length=max_label_length,
        num_beams=num_beams,
        language="<|en|>",
        task="transcribe",
        return_timestamps=return_timestamps,
    )

    def generate_step(student_params, batch):
        """Decode token ids for one per-device batch with the student model.

        Uses the enclosing `gen_kwargs` (max length, beam count, forced language/task,
        timestamps) and returns only the generated `sequences`.
        """
        model_inputs = batch[model_input_name]
        mask = batch.get("attention_mask")
        generation = student_model.generate(model_inputs, attention_mask=mask, params=student_params, **gen_kwargs)
        return generation.sequences

    # Give every local device its own copy of the student train state and the teacher weights
    student_state = student_state.replicate()
    teacher_params = jax_utils.replicate(teacher_params)

    # Data-parallel step functions over the "batch" axis.
    # Training:
    # - arg 0 (the state) is donated;
    # - args 3/4 (freeze_encoder, share_hidden_states) are static;
    # - arg 5 (temperature) is broadcast unchanged to every device.
    p_train_step = jax.pmap(
        train_step,
        "batch",
        in_axes=(0, 0, 0, None, None, None),
        donate_argnums=(0,),
        static_broadcasted_argnums=(3, 4),
    )
    p_eval_step, p_generate_step = jax.pmap(eval_step, "batch"), jax.pmap(generate_step, "batch")

    # Summary of the run configuration; "Num examples" is the total number of examples the
    # optimisation will consume (steps x global batch x accumulation), not the dataset size.
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {total_train_steps * train_batch_size * gradient_accumulation_steps}")
    logger.info(f"  Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {gradient_accumulation_steps}")
    logger.info(
        "  Total train batch size (w. parallel & distributed) = "
        f"{train_batch_size * gradient_accumulation_steps}"
    )
    logger.info(f"  Total optimization steps = {total_train_steps}")

    # ======================== Training ================================
    train_time, train_start, train_metrics = 0, time.time(), []

    # The replicated step counter is 0 for a fresh run and the restored step when resuming.
    # NOTE(review): this one value is used both as the optimizer-update counter (`cur_step`)
    # and as the number of loader batches to skip. With gradient_accumulation_steps > 1 those
    # units may differ, depending on how `student_state.step` is advanced — confirm.
    batches_to_skip = jax.device_get(unreplicate(student_state.step))
    cur_step = int(batches_to_skip)  # will be zero if starting from scratch
    epochs_trained = batches_to_skip // steps_per_epoch
    steps_trained_progress_bar = tqdm(range(total_train_steps), desc="Train steps ... ", position=0)
    steps_trained_progress_bar.update(batches_to_skip)
    continue_training, minibatch_steps = True, 0

    if batches_to_skip > 0:
        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f"  Continuing training from epoch {epochs_trained}")
        logger.info(f"  Continuing training from global step {batches_to_skip}")

    # Shuffled training loader; `skip_batches` presumably fast-forwards past batches a
    # resumed run has already consumed.
    train_loader = get_data_loader(
        training_args.seed,
        vectorized_datasets["train"],
        batch_size=train_batch_size,
        data_collator=data_collator,
        dataloader_num_workers=dataloader_num_workers,
        prefetch_size=dataloader_prefetch_size,
        skip_batches=batches_to_skip,
    )

    # Main epoch loop. `cur_step` counts optimizer updates: it advances only once every
    # `gradient_accumulation_steps` minibatches (see `update_step`). Logging, checkpointing and
    # evaluation are keyed on update steps, and always fire once `cur_step == total_train_steps`.
    for epoch in range(epochs_trained, num_epochs):
        # iterable datasets are told the current epoch (presumably to vary their shuffle order)
        if hasattr(train_loader, "dataset") and isinstance(train_loader.dataset, IterableDataset):
            train_loader.dataset.set_epoch(epoch)

        for batch in train_loader:
            # an optimizer update completes after `gradient_accumulation_steps` minibatches
            minibatch_steps += 1
            update_step = minibatch_steps == gradient_accumulation_steps

            if update_step:
                steps_trained_progress_bar.update(1)
                cur_step += 1
                minibatch_steps = 0

            # shard the host batch for the pmapped step; p_train_step runs on every minibatch
            # (accumulation is presumably handled inside `apply_gradients` — not visible here)
            batch = shard(batch.data)
            student_state, train_metric = p_train_step(
                student_state,
                teacher_params,
                batch,
                training_args.freeze_encoder,
                share_hidden_states,
                training_args.temperature,
            )

            # every `logging_steps` updates: print the loss/LR and log to W&B; only these
            # sampled metrics are kept in `train_metrics` for the later TensorBoard write
            if cur_step % training_args.logging_steps == 0 and update_step:
                train_metrics.append(train_metric)
                train_metric_to_write = unreplicate(train_metric)
                steps_trained_progress_bar.write(
                    f"Step... ({cur_step} / {total_train_steps} | Loss:"
                    f" {train_metric_to_write['loss']}, Learning Rate:"
                    f" {train_metric_to_write['learning_rate']})"
                )
                if has_wandb and jax.process_index() == 0:
                    write_wandb_metric(
                        wandb_logger,
                        train_metric_to_write,
                        train_time + time.time() - train_start,
                        cur_step,
                        epoch,
                        prefix="train",
                    )

            # save checkpoint and weights after each save_steps and at the end of training
            # (only process 0 writes weights / train state and pushes to the Hub)
            if (cur_step % training_args.save_steps == 0 and update_step) or cur_step == total_train_steps:
                if jax.process_index() == 0:
                    save_hf_weights(
                        student_state,
                        student_model,
                        processor,
                        training_args.output_dir,
                        cur_step,
                        total_train_steps,
                        use_scan=training_args.use_scan,
                    )
                    if training_args.save_train_state:
                        student_state.save_state(
                            training_args.output_dir, save_total_limit=training_args.save_total_limit
                        )
                    if training_args.push_to_hub:
                        repo.push_to_hub(
                            commit_message=f"Saving train state of step {cur_step}",
                            blocking=False,
                        )

            # evaluate every `eval_steps` updates and at the end of training
            if training_args.do_eval and (
                (cur_step % eval_steps == 0 and update_step) or cur_step == total_train_steps
            ):
                # pause the train timer while evaluating
                train_time += time.time() - train_start
                # ======================== Evaluating ==============================
                for eval_split in all_eval_splits:
                    eval_metrics = []
                    eval_preds = []
                    eval_labels = []
                    eval_start = time.time()

                    # deterministic, non-dropping loader so every eval example is seen once
                    eval_loader = get_data_loader(
                        training_args.seed,
                        vectorized_datasets[eval_split],
                        batch_size=eval_batch_size,
                        data_collator=data_collator,
                        shuffle=False,
                        drop_last=False,
                        dataloader_num_workers=dataloader_num_workers,
                    )
                    for batch in tqdm(eval_loader, desc=f"Evaluating {eval_split}...", position=2):
                        # Model forward
                        # keep host-side labels for the WER computation below
                        labels = batch["labels"]

                        # params (args 0 and 1) are already replicated, so they are passed through as
                        # static; only the batch is padded/sharded.
                        # NOTE(review): if pad_shard_unpad zero-pads partial batches (as flax's
                        # implementation does), padded rows carry label id 0, which `labels >= 0`
                        # counts as real tokens in the eval losses — confirm
                        metrics = pad_shard_unpad(
                            p_eval_step,
                            static_argnums=(
                                0,
                                1,
                            ),
                            static_return=True,
                        )(
                            student_state.params,
                            teacher_params,
                            batch.data,
                            min_device_batch=per_device_eval_batch_size,
                        )
                        eval_metrics.append(metrics)

                        # generation
                        if training_args.predict_with_generate:
                            generated_ids = pad_shard_unpad(p_generate_step)(
                                student_state.params, batch.data, min_device_batch=per_device_eval_batch_size
                            )
                            eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
                            eval_labels.extend(labels)

                    eval_time = time.time() - eval_start

                    # normalize eval metrics
                    # NOTE(review): this is an unweighted mean of per-batch values, so a smaller
                    # final batch counts as much as a full one
                    eval_metrics = get_metrics(eval_metrics)
                    eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)

                    # compute WER metric
                    wer_desc = ""
                    if training_args.predict_with_generate:
                        wer_metric, pred_str, label_str, norm_pred_str, norm_label_str = compute_metrics(
                            eval_preds, eval_labels
                        )
                        eval_metrics.update(wer_metric)
                        wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])

                    # Print metrics and update progress bar
                    steps_trained_progress_bar.write(
                        f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
                        f" {wer_desc})"
                    )

                    if has_tensorboard and jax.process_index() == 0:
                        write_eval_metric(
                            summary_writer,
                            eval_metrics,
                            cur_step,
                            prefix=eval_split,
                        )

                    if has_wandb and jax.process_index() == 0:
                        write_wandb_metric(wandb_logger, eval_metrics, eval_time, cur_step, epoch, prefix=eval_split)
                        if training_args.predict_with_generate:
                            write_wandb_pred(
                                wandb_logger,
                                pred_str,
                                label_str,
                                norm_pred_str,
                                norm_label_str,
                                cur_step,
                                prefix=eval_split,
                            )

                # NOTE(review): train metrics reach TensorBoard only here, i.e. only when
                # evaluation runs; with do_eval disabled they are never written
                if has_tensorboard and jax.process_index() == 0:
                    # we'll only log to tensorboard every eval steps
                    write_train_metric(
                        summary_writer,
                        train_metrics,
                        train_time,
                        cur_step,
                        training_args.logging_steps,
                    )

                # flush the train metrics and restart the train timer
                train_start = time.time()
                train_metrics = []

            # break condition
            if cur_step == total_train_steps:
                continue_training = False
                break

        # propagate the inner break out of the epoch loop
        if not continue_training:
            break