def main()

in training/flax/convert_train_state_to_hf.py
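
The listing below assumes roughly the following imports, inferred from the symbols it uses; the exact import paths in the repository may differ.

import os
import sys
from pathlib import Path

import jax
import jax.numpy as jnp
import optax
from flax import traverse_util
from flax.serialization import from_bytes
from huggingface_hub import Repository, create_repo, get_full_repo_name
from transformers import AutoConfig, HfArgumentParser, Seq2SeqTrainingArguments

# ModelArguments, the scan-aware FlaxWhisperForConditionalGeneration, TrainState,
# create_learning_rate_fn and logger are defined elsewhere in the training package.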


def main():
    # 1. Parse input arguments
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser(
        (
            ModelArguments,
            Seq2SeqTrainingArguments,
        )
    )

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, training_args = parser.parse_args_into_dataclasses()
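    # Example invocations (file and directory names are hypothetical, for illustration only):
    #   python convert_train_state_to_hf.py convert_args.json
    #   python convert_train_state_to_hf.py --model_name_or_path <path-to-init-model> \
    #       --output_dir <output-dir> --resume_from_checkpoint <checkpoint-dir>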

    # Handle the repository creation
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token,
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
        repo = Repository(
            training_args.output_dir,
            clone_from=repo_name,
            token=training_args.hub_token,
        )
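        # Repository(...) clones `repo_name` into `output_dir`, so the converted weights
        # written there can later be pushed back with `repo.push_to_hub(...)`.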

    # 2. Load pretrained config and model
    config = AutoConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=getattr(jnp, model_args.dtype),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        _do_init=False,
        use_scan=model_args.load_with_scan_weights,
    )
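    # With _do_init=False the Flax model is returned together with its parameters rather
    # than being randomly initialised on device; use_scan tells the loader whether the
    # checkpoint weights are stored in the scanned (stacked per-layer) layout.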

    # Enable scan in the student model if necessary
    if model_args.use_scan:
        student_model.enable_scan()  # enable scan in the nn.Module
        student_params = student_model.convert_unroll_to_scan(student_params)  # convert the unrolled params to the scanned layout
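        # Scanning stacks the weights of the repeated Transformer blocks so they can be run
        # under jax.lax.scan, which keeps JIT compile time roughly independent of model depth.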

    # Initialize our student state
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    total_train_steps = int(training_args.max_steps)

    # Create learning rate schedule
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        total_train_steps,
        training_args.lr_scheduler_type,
        training_args.warmup_steps,
        training_args.learning_rate,
    )
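    # create_learning_rate_fn is a repo-local helper; given its arguments it is expected to
    # return an optax schedule with `warmup_steps` of warm-up followed by a decay of the
    # configured `lr_scheduler_type` over the remaining `total_train_steps`.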

    # We use Optax's "masking" functionality so that weight decay is not applied
    # to bias and LayerNorm scale parameters. decay_mask_fn returns a boolean
    # mask with the same structure as the parameters; the mask is True for
    # parameters that should be decayed.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # find out all LayerNorm parameters
        layer_norm_candidates = [
            "layer_norm",
            "self_attn_layer_norm",
            "final_layer_norm",
            "encoder_attn_layer_norm",
        ]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
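        # For example (illustrative paths): ("model", "decoder", "layers", "0", "fc1", "kernel")
        # maps to True (decayed), while ("model", "decoder", "layers", "0", "self_attn_layer_norm", "scale")
        # and any path ending in "bias" map to False (no weight decay).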
        return traverse_util.unflatten_dict(flat_mask)

    # Create AdamW optimizer
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )
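    # optax.adamw accepts a callable mask; it is applied to the params pytree so that
    # weight decay only touches leaves where decay_mask_fn returns True.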

    # Setup train state
    student_state = TrainState.create(
        apply_fn=student_model.__call__,
        params=student_params,
        tx=adamw,
        dropout_rng=dropout_rng,
        max_grad_norm=training_args.max_grad_norm,
    )
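    # dropout_rng and max_grad_norm are not fields of the stock flax TrainState, so the
    # TrainState used here is assumed to be a repo-local subclass that adds them.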

    if training_args.resume_from_checkpoint is not None:
        if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
            logger.info(
                f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
                "this behavior, omit the resume_from_checkpoint argument."
            )
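            # from_bytes deserialises the msgpack blob into the same pytree structure as
            # student_state, restoring parameters, optimizer state and the step counter.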
            with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
                student_state = from_bytes(student_state, f.read())
        else:
            logger.warning(
                f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
                f"you pass the path to a folder with a valid checkpoint for your model."
            )

    cur_step = int(jax.device_get(student_state.step))

    # Save weights in HF Transformers format
    if jax.process_index() == 0:
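        # HF Transformers checkpoints store unrolled per-layer weights, so any scanned
        # parameters are converted back to the unrolled layout before save_pretrained.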
        student_model.disable_scan()
        student_state_params = student_model.convert_scan_to_unroll(student_state.params)
        student_params = jax.device_get(student_state_params)
        student_model.save_pretrained(
            os.path.join(training_args.output_dir, f"checkpoint-{cur_step}"), params=student_params
        )
        if training_args.push_to_hub:
            repo.push_to_hub(
                commit_message=f"Saving weights of step {cur_step}",
                blocking=False,
            )
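            # blocking=False returns immediately and lets the git push run in the background.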