in scripts/setfit/run_fewshot_distillation.py [0:0]
def __init__(self, args, mode, trained_teacher_model, teacher_train_dataset, student_train_dataset) -> None:
    """Configure one training run and load the model it will train.

    Selects the model name and results-directory prefix according to
    ``mode``, creates the results directory next to this script, stores a
    copy of the script (with the invoking command line appended) in it,
    picks the dataset-to-metric mapping and loads the SetFit base model.

    Args:
        args: Parsed run configuration. Attributes read here:
            ``teacher_model``, ``student_model``, ``baseline_student_model``,
            ``baseline_model_epochs``, ``baseline_model_batch_size``,
            ``loss``, ``classifier``, ``num_iterations_student``,
            ``batch_size_student``, ``exp_name``, ``is_dev_set``,
            ``is_test_set``, ``datasets``, ``max_seq_length`` and
            ``add_normalization_layer``.
        mode: Which run to set up: ``0`` (teacher), ``1`` (SetFit student)
            or ``2`` (baseline student). The same values are exposed as
            ``self.TEACHER``, ``self.SETFIT_STUDENT`` and
            ``self.BASELINE_STUDENT``.
        trained_teacher_model: Stored on the instance in both student
            modes; ignored in teacher mode.
        teacher_train_dataset: Stored on the instance in both student
            modes; ignored in teacher mode.
        student_train_dataset: Stored on the instance in baseline-student
            mode only.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """
    self.args = args

    # these attributes refer to the different modes to run the training
    self.TEACHER = 0
    self.SETFIT_STUDENT = 1
    self.BASELINE_STUDENT = 2

    if mode == self.TEACHER:
        model = args.teacher_model
        path_prefix = f"setfit_teacher_{args.teacher_model.replace('/', '-')}"
        self.mode = self.TEACHER
    elif mode == self.SETFIT_STUDENT:
        model = args.student_model
        path_prefix = f"setfit_student_{args.student_model.replace('/', '-')}"
        self.trained_teacher_model = trained_teacher_model
        self.teacher_train_dataset = teacher_train_dataset
        self.mode = self.SETFIT_STUDENT
    elif mode == self.BASELINE_STUDENT:
        model = args.baseline_student_model
        # NOTE(review): the directory prefix is built from args.student_model
        # although the model loaded is args.baseline_student_model. Kept as-is
        # so existing result paths do not move -- confirm whether intended.
        path_prefix = f"baseline_student_{args.student_model.replace('/', '-')}"
        self.trained_teacher_model = trained_teacher_model
        self.teacher_train_dataset = teacher_train_dataset
        self.student_train_dataset = student_train_dataset
        self.mode = self.BASELINE_STUDENT
        self.bl_stdnt_distill = BaselineDistillation(
            args.baseline_student_model,
            args.baseline_model_epochs,
            args.baseline_model_batch_size,
        )
    else:
        # Fail early with a clear message instead of hitting an unbound
        # local further down, and before anything is written to disk.
        raise ValueError(
            f"Unsupported mode {mode!r}; expected {self.TEACHER} (teacher), "
            f"{self.SETFIT_STUDENT} (SetFit student) or "
            f"{self.BASELINE_STUDENT} (baseline student)."
        )

    # Prepare directory for results. rstrip("-") removes the trailing dash
    # that is left over when args.exp_name is empty.
    run_name = (
        f"{path_prefix}-{args.loss}-{args.classifier}"
        f"-student_iters_{args.num_iterations_student}"
        f"-batch_{args.batch_size_student}-{args.exp_name}"
    ).rstrip("-")
    parent_directory = pathlib.Path(__file__).parent.absolute()
    self.output_path = parent_directory / "results" / run_name
    os.makedirs(self.output_path, exist_ok=True)

    # Save a copy of this training script and the run command in results directory
    train_script_path = os.path.join(self.output_path, "train_script.py")
    copyfile(__file__, train_script_path)
    # Explicit UTF-8 so the appended footer matches Python's default source
    # encoding even when the command line contains non-ASCII characters.
    with open(train_script_path, "a", encoding="utf-8") as f_out:
        f_out.write("\n\n# Script was called via:\n#python " + " ".join(sys.argv))

    # Configure dataset <> metric mapping. Defaults to accuracy
    if args.is_dev_set:
        self.dataset_to_metric = DEV_DATASET_TO_METRIC
    elif args.is_test_set:
        self.dataset_to_metric = TEST_DATASET_TO_METRIC
    else:
        self.dataset_to_metric = {dataset: "accuracy" for dataset in args.datasets}

    # Configure loss function. This is fixed here; args.loss is only used
    # in the results directory name above.
    self.loss_class = losses.CosineSimilarityLoss

    self.model_name = model

    # Load SetFit Model
    self.model_wrapper = SetFitBaseModel(
        model,
        max_seq_length=args.max_seq_length,
        add_normalization_layer=args.add_normalization_layer,
    )
    self.model = self.model_wrapper.model