def parse_args()

in training/flax/create_student_model.py [0:0]


def _str_to_bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` simply calls ``bool()`` on the raw string,
    so every non-empty value (including ``"False"``) is truthy. This converter
    interprets the usual spellings explicitly instead.

    Args:
        value: The raw command-line string (or an already-parsed ``bool``).

    Returns:
        ``True`` or ``False`` according to the spelling of ``value``.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalised = value.strip().lower()
    if normalised in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays falsy, matching what ``bool("")`` returned before.
    if normalised in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value (e.g. True or False), got {value!r}.")


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments used to initialise a student model.

    Required options are ``--teacher_checkpoint`` and ``--save_dir``. The
    remaining options (``--subfolder``, ``--encoder_layers``,
    ``--decoder_layers``, ``--max_source_positions``, ``--push_to_hub`` and
    ``--cache_dir``) fall back to the defaults declared below.

    ``--push_to_hub`` accepts an explicit boolean value (``--push_to_hub True``
    / ``--push_to_hub False``) or can be given bare as a switch
    (``--push_to_hub``), which enables it.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``, which reads
        the arguments from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
    )
    parser.add_argument(
        "--teacher_checkpoint",
        type=str,
        required=True,
        help="The HF Hub ID of the teacher checkpoint.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default="",
        help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
        "can specify the folder name here.",
    )
    parser.add_argument(
        "--encoder_layers",
        type=int,
        default=None,
        help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
    )
    parser.add_argument(
        "--decoder_layers",
        type=int,
        default=2,
        help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
    )
    parser.add_argument(
        "--max_source_positions",
        type=int,
        default=None,
        help="The maximum sequence length of log-mel filter-bank features that this model might ever be used with. Can "
        "be used to create a student model with a shorter context length than the teacher model. Defaults to the number "
        "of source positions in the teacher model (1500).",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Where to save the student weights and processor.",
    )
    parser.add_argument(
        "--push_to_hub",
        # ``type=bool`` would treat any non-empty string, including "False",
        # as True, so the value is parsed explicitly instead.
        type=_str_to_bool,
        # Allow the flag to be passed without a value, in which case it is on.
        nargs="?",
        const=True,
        required=False,
        default=False,
        help="Whether to push the student weights and processor to the Hub.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where to store the pretrained models downloaded from huggingface.co",
    )

    args = parser.parse_args()
    return args