in scripts/setfit/run_fewshot_distillation.py [0:0]
def train(self):
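    """Train and evaluate a model for each dataset in self.dataset_to_metric.

    Depending on self.mode, this trains the SetFit teacher, the distilled SetFit student,
    or the baseline student, and writes one results.json per few-shot split.
    """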
    for dataset, metric in self.dataset_to_metric.items():
        if self.mode == self.TEACHER:
            print("\n\n\n=========== Training Teacher =========")
        if self.mode == self.SETFIT_STUDENT:
            print("\n\n\n======== Training SetFit Student ======")
        if self.mode == self.BASELINE_STUDENT:
            print("\n\n\n======== Training Baseline Student ======")
        print(f"\n\n\n============== {dataset} ============")

        # Load one of the SetFit training sets from the Hugging Face Hub
        train_ds = load_dataset(f"SetFit/{dataset}", split="train")
        eval_dataset = load_dataset(f"SetFit/{dataset}", split="test")
        print(f"Test set: {len(eval_dataset)}")
        # For teacher training, use only one split (pass a single seed: seed=0)
        if self.mode == self.TEACHER:
            fewshot_ds = self.create_fewshot_splits(
                train_ds,
                self.args.teacher_sample_sizes,
                seeds=TEACHER_SEED,
                mode=self.TEACHER,
            )
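        # For the SetFit student, create splits for multiple sample sizes and seeds,
        # and keep them so the baseline student run can reuse the same data.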
        if self.mode == self.SETFIT_STUDENT:
            fewshot_ds = self.create_fewshot_splits(
                train_ds,
                self.args.student_sample_sizes,
                seeds=STUDENT_SEEDS,
                mode=self.SETFIT_STUDENT,
            )
            self.student_train_dataset = fewshot_ds

        # For the baseline student, reuse the same data that was used to train the SetFit student
        if self.mode == self.BASELINE_STUDENT:
            fewshot_ds = self.student_train_dataset
            num_classes = len(train_ds.unique("label"))
            self.bl_stdnt_distill.update_metric(metric)
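        # Train and evaluate one model per few-shot split; each split gets its own results.json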
        for name in fewshot_ds:
            results_path = os.path.join(self.output_path, dataset, name, "results.json")
            print(f"\n\n======== {os.path.dirname(results_path)} =======")
            os.makedirs(os.path.dirname(results_path), exist_ok=True)

            if self.mode == self.TEACHER:
                teacher_model = SetFitModel.from_pretrained(self.model_name)
                teacher_trainer = SetFitTrainer(
                    model=teacher_model,
                    train_dataset=fewshot_ds[name],
                    eval_dataset=eval_dataset,
                    loss_class=losses.CosineSimilarityLoss,
                    metric=metric,
                    batch_size=self.args.batch_size_teacher,
                    num_iterations=self.args.num_iterations_teacher,  # The number of text pairs to generate for contrastive learning
                    num_epochs=1,  # The number of epochs to use for contrastive learning
                )
                teacher_trainer.train()

                # Evaluate the model on the test data
                metrics = teacher_trainer.evaluate()
                print("Teacher metrics: ", metrics)

                self.teacher_train_dataset = fewshot_ds[name]  # save teacher training data
                self.trained_teacher_model = teacher_trainer.model
            if self.mode == self.SETFIT_STUDENT:
                # student train data = teacher train data + unlabeled data
                student_train_dataset = concatenate_datasets([self.teacher_train_dataset, fewshot_ds[name]])
                student_model = SetFitModel.from_pretrained(self.model_name)
                student_trainer = DistillationSetFitTrainer(
                    teacher_model=self.trained_teacher_model,
                    train_dataset=student_train_dataset,
                    student_model=student_model,
                    eval_dataset=eval_dataset,
                    loss_class=losses.CosineSimilarityLoss,
                    metric=metric,
                    batch_size=self.args.batch_size_student,
                    num_iterations=self.args.num_iterations_student,  # The number of text pairs to generate for contrastive learning
                    # column_mapping={"sentence": "text", "label": "label"}  # Map dataset columns to text/label expected by trainer
                )

                # Train and evaluate the student
                student_trainer.train()
                metrics = student_trainer.evaluate()
                print("Student metrics: ", metrics)
            if self.mode == self.BASELINE_STUDENT:
                student_train_dataset = concatenate_datasets([self.teacher_train_dataset, fewshot_ds[name]])
                metrics = self.train_baseline_student(student_train_dataset, eval_dataset, num_classes)
                print("Baseline model score: ", round(metrics[metric] * 100, 3))
            with open(results_path, "w") as f_out:
                json.dump(
                    {"score": round(metrics[metric] * 100, 3), "measure": metric},
                    f_out,
                    sort_keys=True,
                )