def main()

in scripts/setfit/run_fewshot.py


def main():
    args = parse_args()

    parent_directory = pathlib.Path(__file__).parent.absolute()
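    # Results land in results/<model>-<loss>-<classifier>-iterations_<N>-batch_<N>-<exp_name>;
    # the trailing rstrip("-") drops the dangling dash when --exp_name is empty.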
    output_path = (
        parent_directory
        / "results"
        / f"{args.model.replace('/', '-')}-{args.loss}-{args.classifier}-iterations_{args.num_iterations}-batch_{args.batch_size}-{args.exp_name}".rstrip(
            "-"
        )
    )
    os.makedirs(output_path, exist_ok=True)

    # Save a copy of this training script and the run command in results directory
    train_script_path = os.path.join(output_path, "train_script.py")
    copyfile(__file__, train_script_path)
    with open(train_script_path, "a") as f_out:
        f_out.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    # Configure dataset <> metric mapping. Defaults to accuracy
    if args.is_dev_set:
        dataset_to_metric = DEV_DATASET_TO_METRIC
    elif args.is_test_set:
        dataset_to_metric = TEST_DATASET_TO_METRIC
    else:
        dataset_to_metric = {dataset: "accuracy" for dataset in args.datasets}

    # Configure loss function
    loss_class = LOSS_NAME_TO_CLASS[args.loss]

    for dataset, metric in dataset_to_metric.items():
        few_shot_train_splits, test_data = load_data_splits(dataset, args.sample_sizes, args.add_data_augmentation)
        print(f"Evaluating {dataset} using {metric!r}.")

        # Warn if the test set is significantly class-imbalanced
        counter = Counter(test_data["label"])
        label_samples = sorted(counter.items(), key=lambda item: item[1])
        smallest_n_samples = label_samples[0][1]
        largest_n_samples = label_samples[-1][1]
        # If the largest class is more than 50% larger than the smallest
        if largest_n_samples > smallest_n_samples * 1.5:
            warnings.warn(
                "The test set has a class imbalance "
                f"({', '.join(f'label {label} w. {n_samples} samples' for label, n_samples in label_samples)})."
            )

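        # Train and evaluate on each sampled few-shot training split; finished
        # experiments are skipped unless --override_results is passed.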
        for split_name, train_data in few_shot_train_splits.items():
            results_path = create_results_path(dataset, split_name, output_path)
            if os.path.exists(results_path) and not args.override_results:
                print(f"Skipping finished experiment: {results_path}")
                continue

            # Load model
            if args.classifier == "pytorch":
                model = SetFitModel.from_pretrained(
                    args.model,
                    use_differentiable_head=True,
                    head_params={"out_features": len(set(train_data["label"]))},
                )
            else:
                model = SetFitModel.from_pretrained(args.model)
            model.model_body.max_seq_length = args.max_seq_length
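            # Optionally set module "2" of the sentence-transformers body to an
            # L2-normalization layer so the produced embeddings are unit-normalized.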
            if args.add_normalization_layer:
                model.model_body._modules["2"] = models.Normalize()

            # Train on current split
            trainer = SetFitTrainer(
                model=model,
                train_dataset=train_data,
                eval_dataset=test_data,
                metric=metric,
                loss_class=loss_class,
                batch_size=args.batch_size,
                num_epochs=args.num_epochs,
                num_iterations=args.num_iterations,
            )
            if not args.eval_strategy:
                trainer.args.eval_strategy = "no"
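            # Differentiable (pytorch) head: first fine-tune the body with the head
            # frozen, then unfreeze the head (optionally keeping the body frozen) and
            # train it with the settings below.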
            if args.classifier == "pytorch":
                trainer.freeze()
                trainer.train()
                trainer.unfreeze(keep_body_frozen=args.keep_body_frozen)
                trainer.train(
                    num_epochs=25,
                    body_learning_rate=1e-5,
                    learning_rate=args.lr,  # recommend: 1e-2
                    l2_weight=0.0,
                    batch_size=args.batch_size,
                )
            else:
                trainer.train()

            # Evaluate the model on the test data
            metrics = trainer.evaluate()
            print(f"Metrics: {metrics}")

            with open(results_path, "w") as f_out:
                json.dump(
                    {"score": metrics[metric] * 100, "measure": metric},
                    f_out,
                    sort_keys=True,
                )

    # Create a summary_table.csv file that computes means and standard deviations
    # for all of the results in `output_path`.
    create_summary_table(str(output_path))
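

The function reads its configuration from parse_args(), which is not part of this excerpt. As a rough orientation only, the sketch below reconstructs the argument surface the body above relies on: the flag names mirror the args.<name> accesses in main(), while the types and defaults are illustrative assumptions rather than the script's actual values.

import argparse


def parse_args():
    # Hypothetical reconstruction: flag names follow the attribute accesses in
    # main(); types and defaults are illustrative, not the script's real ones.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="sentence-transformers/paraphrase-mpnet-base-v2")
    parser.add_argument("--datasets", nargs="+", default=["sst2"])
    parser.add_argument("--sample_sizes", type=int, nargs="+", default=[8, 64])
    parser.add_argument("--loss", default="CosineSimilarityLoss")
    parser.add_argument("--classifier", default="logistic_regression", choices=["logistic_regression", "pytorch"])
    parser.add_argument("--num_iterations", type=int, default=20)
    parser.add_argument("--num_epochs", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--max_seq_length", type=int, default=256)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--exp_name", default="")
    parser.add_argument("--eval_strategy", default="")
    parser.add_argument("--add_normalization_layer", action="store_true")
    parser.add_argument("--add_data_augmentation", action="store_true")
    parser.add_argument("--is_dev_set", action="store_true")
    parser.add_argument("--is_test_set", action="store_true")
    parser.add_argument("--override_results", action="store_true")
    parser.add_argument("--keep_body_frozen", action="store_true")
    return parser.parse_args()

With an argument surface like this, an illustrative invocation would be, for example, python scripts/setfit/run_fewshot.py --datasets sst2 --classifier pytorch --exp_name my-run; per-split scores are then written as JSON under results/ and aggregated by create_summary_table.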