def main()

in scripts/setfit/run_zeroshot.py
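# Module-level imports assumed by this excerpt (declared at the top of run_zeroshot.py):
# json, os, pathlib, sys, warnings, Counter (collections), copyfile (shutil),
# load_dataset (datasets), models (sentence_transformers), and
# SetFitModel, SetFitTrainer, get_templated_dataset (setfit), plus the
# parse_args, create_results_path, LOSS_NAME_TO_CLASS, DEV_DATASET_TO_METRIC
# and TEST_DATASET_TO_METRIC helpers defined in the script or its utilities.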


def main():
    args = parse_args()

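    # Build a unique results directory name from the model, loss, classifier and run settings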
    parent_directory = pathlib.Path(__file__).parent.absolute()
    output_path = (
        parent_directory
        / "results"
        / f"{args.model.replace('/', '-')}-{args.loss}-{args.classifier}-iterations_{args.num_iterations}-batch_{args.batch_size}-{args.exp_name}".rstrip(
            "-"
        )
    )
    os.makedirs(output_path, exist_ok=True)

    # Save a copy of this training script and the run command in results directory
    train_script_path = os.path.join(output_path, "train_script.py")
    copyfile(__file__, train_script_path)
    with open(train_script_path, "a") as f_out:
        f_out.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    # Configure loss function
    loss_class = LOSS_NAME_TO_CLASS[args.loss]

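    # Resolve the evaluation metric for this dataset, defaulting to accuracy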
    metric = DEV_DATASET_TO_METRIC.get(args.eval_dataset, TEST_DATASET_TO_METRIC.get(args.eval_dataset, "accuracy"))

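    # With no reference dataset or candidate labels given, take the label names from the evaluation dataset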
    if args.reference_dataset is None and args.candidate_labels is None:
        args.reference_dataset = args.eval_dataset

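    # Build a synthetic zero-shot training set by inserting each label name into a text template
    # (sample_size examples per label)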
    train_data = get_templated_dataset(
        reference_dataset=args.reference_dataset,
        candidate_labels=args.candidate_labels,
        sample_size=args.aug_sample_size,
        label_names_column=args.label_names_column,
    )
    test_data = load_dataset(args.eval_dataset, split="test")

    print(f"Evaluating {args.eval_dataset} using {metric!r}.")

    # Warn if the test set has a significant class imbalance
    counter = Counter(test_data["label"])
    label_samples = sorted(counter.items(), key=lambda item: item[1])
    smallest_n_samples = label_samples[0][1]
    largest_n_samples = label_samples[-1][1]
    # If the largest class is more than 50% larger than the smallest
    if largest_n_samples > smallest_n_samples * 1.5:
        warnings.warn(
            "The test set has a class imbalance "
            f"({', '.join(f'label {label} w. {n_samples} samples' for label, n_samples in label_samples)})."
        )

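    # Skip the run if a results file already exists and overriding was not requested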
    results_path = create_results_path(args.eval_dataset, "zeroshot", output_path)
    if os.path.exists(results_path) and not args.override_results:
        print(f"Skipping finished experiment: {results_path}")
        sys.exit(0)

    # Load model
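    # The "pytorch" classifier uses a differentiable torch head sized to the number of
    # labels; otherwise SetFit falls back to its default logistic regression head.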
    if args.classifier == "pytorch":
        model = SetFitModel.from_pretrained(
            args.model,
            use_differentiable_head=True,
            head_params={"out_features": len(set(train_data["label"]))},
        )
    else:
        model = SetFitModel.from_pretrained(args.model)
    model.model_body.max_seq_length = args.max_seq_length
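    # Optionally L2-normalize the embeddings by adding a Normalize module to the model body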
    if args.add_normalization_layer:
        model.model_body._modules["2"] = models.Normalize()

    # Train on current split
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=test_data,
        metric=metric,
        loss_class=loss_class,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        num_iterations=args.num_iterations,
    )
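    # With the differentiable head: first fine-tune the body with the head frozen,
    # then unfreeze the head and train it (optionally keeping the body frozen).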
    if args.classifier == "pytorch":
        trainer.freeze()
        trainer.train()
        trainer.unfreeze(keep_body_frozen=args.keep_body_frozen)
        trainer.train(
            num_epochs=25,
            body_learning_rate=1e-5,
            learning_rate=args.lr,  # recommend: 1e-2
            l2_weight=0.0,
            batch_size=args.batch_size,
        )
    else:
        trainer.train()

    # Evaluate the model on the test data
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")

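    # Persist the score (as a percentage) and the metric name for this run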
    with open(results_path, "w") as f_out:
        json.dump(
            {"score": metrics[metric] * 100, "measure": metric},
            f_out,
            sort_keys=True,
        )