def train()

in scripts/setfit/run_fewshot_distillation.py [0:0]


    def train(self):
        """Run the training stage selected by ``self.mode`` on every configured dataset.

        For each ``(dataset, metric)`` pair in ``self.dataset_to_metric`` this loads the
        ``SetFit/<dataset>`` train and test splits from the Hugging Face Hub, builds
        few-shot training splits, trains and evaluates one model per split, and writes
        ``{"measure": <metric>, "score": round(<metric value> * 100, 3)}`` to
        ``<output_path>/<dataset>/<split name>/results.json``.

        Modes:
            * ``TEACHER``: fine-tunes a SetFit model on splits built with
              ``TEACHER_SEED``. It stores the split and the trained model on
              ``self.teacher_train_dataset`` / ``self.trained_teacher_model``.
            * ``SETFIT_STUDENT``: distills ``self.trained_teacher_model`` into a
              freshly loaded SetFit student. The student is trained on the teacher's
              data concatenated with each student split. The student splits are stored
              on ``self.student_train_dataset``.
            * ``BASELINE_STUDENT``: reuses ``self.student_train_dataset`` and trains a
              baseline via ``self.train_baseline_student``.

        The student modes read attributes that earlier runs in the other modes set.

        Raises:
            ValueError: If ``self.mode`` is not one of the three modes above.
        """
        for dataset, metric in self.dataset_to_metric.items():
            if self.mode == self.TEACHER:
                print("\n\n\n=========== Training Teacher =========")
            elif self.mode == self.SETFIT_STUDENT:
                print("\n\n\n======== Training SetFit Student ======")
            elif self.mode == self.BASELINE_STUDENT:
                print("\n\n\n======== Training Baseline Student ======")
            else:
                # Fail fast, before any downloads. Otherwise an unknown mode only
                # surfaced later as an unrelated NameError on `fewshot_ds`.
                raise ValueError(f"Unknown training mode: {self.mode!r}")
            print(f"\n\n\n============== {dataset} ============")

            # Load one of the SetFit training sets from the Hugging Face Hub
            train_ds = load_dataset(f"SetFit/{dataset}", split="train")
            eval_dataset = load_dataset(f"SetFit/{dataset}", split="test")
            print(f"Test set: {len(eval_dataset)}")

            if self.mode == self.TEACHER:
                # Teacher training uses only one split (a single seed, TEACHER_SEED).
                fewshot_ds = self.create_fewshot_splits(
                    train_ds,
                    self.args.teacher_sample_sizes,
                    seeds=TEACHER_SEED,
                    mode=self.TEACHER,
                )
            elif self.mode == self.SETFIT_STUDENT:
                fewshot_ds = self.create_fewshot_splits(
                    train_ds,
                    self.args.student_sample_sizes,
                    seeds=STUDENT_SEEDS,
                    mode=self.SETFIT_STUDENT,
                )
                # Remembered so a BASELINE_STUDENT run can train on the same splits.
                self.student_train_dataset = fewshot_ds
            else:  # BASELINE_STUDENT (mode was validated above)
                # Reuse the data that was used to train the SetFit student.
                fewshot_ds = self.student_train_dataset
                num_classes = len(train_ds.unique("label"))
                self.bl_stdnt_distill.update_metric(metric)

            for name in fewshot_ds:
                results_path = os.path.join(self.output_path, dataset, name, "results.json")
                print(f"\n\n======== {os.path.dirname(results_path)} =======")
                os.makedirs(os.path.dirname(results_path), exist_ok=True)

                if self.mode == self.TEACHER:
                    teacher_model = SetFitModel.from_pretrained(self.model_name)
                    teacher_trainer = SetFitTrainer(
                        model=teacher_model,
                        train_dataset=fewshot_ds[name],
                        eval_dataset=eval_dataset,
                        loss_class=losses.CosineSimilarityLoss,
                        metric=metric,
                        batch_size=self.args.batch_size_teacher,
                        num_iterations=self.args.num_iterations_teacher,  # The number of text pairs to generate for contrastive learning
                        num_epochs=1,  # The number of epochs to use for contrastive learning
                    )
                    teacher_trainer.train()

                    # Evaluate the model on the test data
                    metrics = teacher_trainer.evaluate()
                    print("Teacher metrics: ", metrics)

                    # NOTE(review): these attributes are reassigned for every split of every
                    # dataset, so afterwards they hold only the last teacher. Confirm the
                    # driver runs with a single dataset, or the student modes will pair every
                    # dataset with that last teacher.
                    self.teacher_train_dataset = fewshot_ds[name]  # save teacher training data
                    self.trained_teacher_model = teacher_trainer.model

                elif self.mode == self.SETFIT_STUDENT:
                    # student train data = teacher train data + unlabeled data
                    student_train_dataset = concatenate_datasets([self.teacher_train_dataset, fewshot_ds[name]])

                    student_model = SetFitModel.from_pretrained(self.model_name)
                    student_trainer = DistillationSetFitTrainer(
                        teacher_model=self.trained_teacher_model,
                        train_dataset=student_train_dataset,
                        student_model=student_model,
                        eval_dataset=eval_dataset,
                        loss_class=losses.CosineSimilarityLoss,
                        # Must be the dataset's metric: the results writer below reads
                        # metrics[metric]. A hard-coded "accuracy" raised KeyError otherwise.
                        metric=metric,
                        batch_size=self.args.batch_size_student,
                        num_iterations=self.args.num_iterations_student,  # The number of text pairs to generate for contrastive learning
                    )
                    # Student train and evaluate
                    student_trainer.train()
                    metrics = student_trainer.evaluate()

                    print("Student metrics: ", metrics)

                else:  # BASELINE_STUDENT
                    student_train_dataset = concatenate_datasets([self.teacher_train_dataset, fewshot_ds[name]])
                    metrics = self.train_baseline_student(student_train_dataset, eval_dataset, num_classes)
                    print("Baseline model score: ", round(metrics[metric] * 100, 3))

                # Every mode reports the configured metric as a percentage (3 d.p.).
                score = round(metrics[metric] * 100, 3)
                with open(results_path, "w") as f_out:
                    json.dump(
                        {"score": score, "measure": metric},
                        f_out,
                        sort_keys=True,
                    )