in scripts/setfit/run_fewshot.py [0:0]
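# NOTE: only main() is shown in this excerpt. The rest of the script is assumed to provide
# the imports used below (os, sys, json, pathlib, warnings, collections.Counter,
# shutil.copyfile, sentence_transformers.models, setfit.SetFitModel, setfit.SetFitTrainer)
# and the helpers parse_args, load_data_splits, create_results_path, create_summary_table,
# plus the DEV_DATASET_TO_METRIC, TEST_DATASET_TO_METRIC and LOSS_NAME_TO_CLASS mappings.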
def main():
    args = parse_args()
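
    # The run's results directory lives next to this script under results/; its name encodes
    # the model, loss, classifier head, pair-generation iterations, batch size and optional
    # experiment name, with the trailing "-" stripped when exp_name is left empty.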
    parent_directory = pathlib.Path(__file__).parent.absolute()
    output_path = (
        parent_directory
        / "results"
        / f"{args.model.replace('/', '-')}-{args.loss}-{args.classifier}-iterations_{args.num_iterations}-batch_{args.batch_size}-{args.exp_name}".rstrip(
            "-"
        )
    )
    os.makedirs(output_path, exist_ok=True)

    # Save a copy of this training script and the run command in the results directory
    train_script_path = os.path.join(output_path, "train_script.py")
    copyfile(__file__, train_script_path)
    with open(train_script_path, "a") as f_out:
        f_out.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    # Configure dataset <> metric mapping. Defaults to accuracy
    if args.is_dev_set:
        dataset_to_metric = DEV_DATASET_TO_METRIC
    elif args.is_test_set:
        dataset_to_metric = TEST_DATASET_TO_METRIC
    else:
        dataset_to_metric = {dataset: "accuracy" for dataset in args.datasets}

    # Configure loss function
    loss_class = LOSS_NAME_TO_CLASS[args.loss]
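
    # One experiment per dataset: each few-shot training split (sampled at the sizes given by
    # args.sample_sizes) is trained separately and evaluated against the dataset's test set.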
    for dataset, metric in dataset_to_metric.items():
        few_shot_train_splits, test_data = load_data_splits(dataset, args.sample_sizes, args.add_data_augmentation)
        print(f"Evaluating {dataset} using {metric!r}.")

        # Report on an imbalanced test set
        counter = Counter(test_data["label"])
        label_samples = sorted(counter.items(), key=lambda label_samples: label_samples[1])
        smallest_n_samples = label_samples[0][1]
        largest_n_samples = label_samples[-1][1]
        # If the largest class is more than 50% larger than the smallest
        if largest_n_samples > smallest_n_samples * 1.5:
            warnings.warn(
                "The test set has a class imbalance "
                f"({', '.join(f'label {label} w. {n_samples} samples' for label, n_samples in label_samples)})."
            )

        for split_name, train_data in few_shot_train_splits.items():
            results_path = create_results_path(dataset, split_name, output_path)
            if os.path.exists(results_path) and not args.override_results:
                print(f"Skipping finished experiment: {results_path}")
                continue
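
            # Two head options: "pytorch" loads a differentiable torch classification head
            # sized to the number of labels in this split; anything else falls back to
            # SetFit's default head (a scikit-learn logistic regression classifier).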
            # Load model
            if args.classifier == "pytorch":
                model = SetFitModel.from_pretrained(
                    args.model,
                    use_differentiable_head=True,
                    head_params={"out_features": len(set(train_data["label"]))},
                )
            else:
                model = SetFitModel.from_pretrained(args.model)
            model.model_body.max_seq_length = args.max_seq_length
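            # Optionally add an L2-normalization layer to the embedding pipeline (slot "2",
            # assuming the usual Transformer (0) + Pooling (1) sentence-transformers layout).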
            if args.add_normalization_layer:
                model.model_body._modules["2"] = models.Normalize()

            # Train on current split
            trainer = SetFitTrainer(
                model=model,
                train_dataset=train_data,
                eval_dataset=test_data,
                metric=metric,
                loss_class=loss_class,
                batch_size=args.batch_size,
                num_epochs=args.num_epochs,
                num_iterations=args.num_iterations,
            )
            if not args.eval_strategy:
                trainer.args.eval_strategy = "no"
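
            # With the differentiable head, training runs in two stages: first the head is
            # frozen while the body is fine-tuned contrastively, then the head is unfrozen
            # (optionally keeping the body frozen) and trained with its own hyperparameters.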
            if args.classifier == "pytorch":
                trainer.freeze()
                trainer.train()
                trainer.unfreeze(keep_body_frozen=args.keep_body_frozen)
                trainer.train(
                    num_epochs=25,
                    body_learning_rate=1e-5,
                    learning_rate=args.lr,  # recommend: 1e-2
                    l2_weight=0.0,
                    batch_size=args.batch_size,
                )
            else:
                trainer.train()

            # Evaluate the model on the test data
            metrics = trainer.evaluate()
            print(f"Metrics: {metrics}")
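
            # Persist this split's score as a percentage; create_summary_table() aggregates
            # these per-split JSON files into a single summary at the end of the run.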
            with open(results_path, "w") as f_out:
                json.dump(
                    {"score": metrics[metric] * 100, "measure": metric},
                    f_out,
                    sort_keys=True,
                )

    # Create a summary_table.csv file that computes means and standard deviations
    # for all of the results in `output_path`.
    create_summary_table(str(output_path))
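
# Example invocation (hypothetical values; flag names assume the argparse options defined in
# parse_args() mirror the args.* attributes used above):
# python scripts/setfit/run_fewshot.py \
#     --model sentence-transformers/paraphrase-mpnet-base-v2 \
#     --loss CosineSimilarityLoss \
#     --classifier logistic_regression \
#     --num_iterations 20 \
#     --batch_size 16 \
#     --datasets sst2 \
#     --exp_name demo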