def parse_args()

in scripts/setfit/run_fewshot_distillation.py [0:0]


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--teacher_model", default="paraphrase-mpnet-base-v2")
    parser.add_argument("--student_model", default="paraphrase-MiniLM-L3-v2")
    parser.add_argument("--baseline_student_model", default="nreimers/MiniLM-L3-H384-uncased")
    parser.add_argument(
        "--datasets",
        nargs="+",
        default=["sst2"],
    )
    parser.add_argument("--teacher_sample_sizes", type=int, nargs="+", default=[16])
    parser.add_argument(
        "--student_sample_sizes",
        type=int,
        nargs="+",
        default=[8, 16, 32, 64, 100, 200, 1000],
    )
    parser.add_argument("--num_iterations_teacher", type=int, default=20)
    parser.add_argument("--num_iterations_student", type=int, default=20)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--batch_size_teacher", type=int, default=16)
    parser.add_argument("--batch_size_student", type=int, default=16)
    parser.add_argument("--max_seq_length", type=int, default=256)
    parser.add_argument("--baseline_model_epochs", type=int, default=10)
    parser.add_argument("--baseline_model_batch_size", type=int, default=16)

    parser.add_argument(
        "--classifier",
        default="logistic_regression",
        choices=[
            "logistic_regression",
            "svc-rbf",
            "svc-rbf-norm",
            "knn",
            "pytorch",
            "pytorch_complex",
        ],
    )
    parser.add_argument("--loss", default="CosineSimilarityLoss")
    parser.add_argument("--exp_name", default="")
    parser.add_argument("--add_normalization_layer", default=False, action="store_true")
    parser.add_argument("--optimizer_name", default="AdamW")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--is_dev_set", type=bool, default=False)
    parser.add_argument("--is_test_set", type=bool, default=False)
    args = parser.parse_args()

    return args