in scripts/setfit/run_zeroshot.py [0:0]
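# NOTE: this excerpt assumes the module-level imports of the full script, roughly:
#   import json, os, pathlib, sys, warnings
#   from collections import Counter
#   from shutil import copyfile
#   from datasets import load_dataset
#   from sentence_transformers import models
#   from setfit import SetFitModel, SetFitTrainer, get_templated_dataset
# plus the script-local helpers parse_args, create_results_path, LOSS_NAME_TO_CLASS,
# DEV_DATASET_TO_METRIC and TEST_DATASET_TO_METRIC.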
def main():
    args = parse_args()
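    # Build a unique results directory name from the model name and the main
    # hyperparameters (a trailing "-" is stripped when no experiment name is given)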
    parent_directory = pathlib.Path(__file__).parent.absolute()
    output_path = (
        parent_directory
        / "results"
        / f"{args.model.replace('/', '-')}-{args.loss}-{args.classifier}-iterations_{args.num_iterations}-batch_{args.batch_size}-{args.exp_name}".rstrip(
            "-"
        )
    )
    os.makedirs(output_path, exist_ok=True)
    # Save a copy of this training script and the run command in results directory
    train_script_path = os.path.join(output_path, "train_script.py")
    copyfile(__file__, train_script_path)
    with open(train_script_path, "a") as f_out:
        f_out.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))
    # Configure loss function
    loss_class = LOSS_NAME_TO_CLASS[args.loss]

    metric = DEV_DATASET_TO_METRIC.get(args.eval_dataset, TEST_DATASET_TO_METRIC.get(args.eval_dataset, "accuracy"))
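    # If neither a reference dataset nor explicit candidate labels were provided,
    # fall back to the evaluation dataset's own label names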
    if args.reference_dataset is None and args.candidate_labels is None:
        args.reference_dataset = args.eval_dataset
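    # Zero-shot setup: instead of labeled examples, generate a synthetic training set
    # by filling a template with the candidate label names (see `get_templated_dataset`)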
    train_data = get_templated_dataset(
        reference_dataset=args.reference_dataset,
        candidate_labels=args.candidate_labels,
        sample_size=args.aug_sample_size,
        label_names_column=args.label_names_column,
    )
    test_data = load_dataset(args.eval_dataset, split="test")

    print(f"Evaluating {args.eval_dataset} using {metric!r}.")
    # Report on an imbalanced test set
    counter = Counter(test_data["label"])
    label_samples = sorted(counter.items(), key=lambda label_samples: label_samples[1])
    smallest_n_samples = label_samples[0][1]
    largest_n_samples = label_samples[-1][1]
    # If the largest class is more than 50% larger than the smallest
    if largest_n_samples > smallest_n_samples * 1.5:
        warnings.warn(
            "The test set has a class imbalance "
            f"({', '.join(f'label {label} w. {n_samples} samples' for label, n_samples in label_samples)})."
        )
    results_path = create_results_path(args.eval_dataset, "zeroshot", output_path)
    if os.path.exists(results_path) and not args.override_results:
        print(f"Skipping finished experiment: {results_path}")
        exit()
    # Load model
    if args.classifier == "pytorch":
        model = SetFitModel.from_pretrained(
            args.model,
            use_differentiable_head=True,
            head_params={"out_features": len(set(train_data["label"]))},
        )
    else:
        model = SetFitModel.from_pretrained(args.model)
    model.model_body.max_seq_length = args.max_seq_length
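    # Optionally add an L2-normalization module to the sentence-transformer body
    # so embeddings are normalized before they reach the classification head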
    if args.add_normalization_layer:
        model.model_body._modules["2"] = models.Normalize()
    # Train on current split
    trainer = SetFitTrainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=test_data,
        metric=metric,
        loss_class=loss_class,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        num_iterations=args.num_iterations,
    )
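    # With the differentiable head, train in two phases: first freeze the head and
    # contrastively fine-tune the sentence-transformer body, then unfreeze (optionally
    # keeping the body frozen) and train the classification head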
    if args.classifier == "pytorch":
        trainer.freeze()
        trainer.train()
        trainer.unfreeze(keep_body_frozen=args.keep_body_frozen)
        trainer.train(
            num_epochs=25,
            body_learning_rate=1e-5,
            learning_rate=args.lr,  # recommend: 1e-2
            l2_weight=0.0,
            batch_size=args.batch_size,
        )
    else:
        trainer.train()
    # Evaluate the model on the test data
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    with open(results_path, "w") as f_out:
        json.dump(
            {"score": metrics[metric] * 100, "measure": metric},
            f_out,
            sort_keys=True,
        )