in PairSupCon/training.py [0:0]
def train(self):
    """Run the pairwise training loop for up to ``self.args.epochs`` epochs.

    Every batch from ``self.train_loader`` is converted with
    ``self.prepare_pairwise_input`` and passed to ``self.train_step``.  The
    returned ``losses`` mapping is forwarded to ``statistics_log`` whenever
    ``self.gstep`` is a multiple of ``self.args.logging_step`` or equals
    either ``epochs * len(train_loader)`` or ``self.args.max_iter``.

    When ``self.args.dev_set == "sts"``, each logging step also calls
    ``self.eval_stsdev()``, merges its metrics into ``losses`` and saves a
    ``best_dev`` checkpoint if ``eval_stsb_spearman`` exceeds
    ``self.dev_objective``.

    Training stops for good once ``self.gstep`` exceeds
    ``self.args.max_iter``: the stop is checked on every step (including
    logging steps) and ends the epoch loop as well, so no further batches
    are trained in later epochs.

    Returns:
        None
    """
    all_iter = self.args.epochs * len(self.train_loader)
    print('\n={}/{}=Iterations/Batches'.format(all_iter, len(self.train_loader)))

    self.model.train()
    for epoch in range(self.args.epochs):
        reached_max_iter = False

        for batch in self.train_loader:
            input_ids, attention_mask, pairsimi = self.prepare_pairwise_input(batch)
            losses = self.train_step(input_ids, attention_mask, pairsimi)

            should_log = (
                self.gstep % self.args.logging_step == 0
                or self.gstep == all_iter
                or self.gstep == self.args.max_iter
            )
            if should_log:
                if self.args.dev_set == "sts":  # testing the effectiveness of stsb as dev set, default False
                    # Switch to eval mode only for the dev evaluation, then
                    # restore train mode before the next training step.
                    self.model.eval()
                    sts_metrics = self.eval_stsdev()
                    losses.update(sts_metrics)
                    if sts_metrics["eval_stsb_spearman"] > self.dev_objective:
                        self.save_model(epoch, best_dev=True)
                        self.dev_objective = sts_metrics["eval_stsb_spearman"]
                    self.model.train()

                statistics_log(self.args.tensorboard, losses=losses, global_step=self.gstep)

            # Checked independently of logging so that a logging step can
            # never mask the iteration budget.
            if self.gstep > self.args.max_iter:
                reached_max_iter = True
                break

            self.gstep += 1

        # NOTE(review): the reviewed copy of this method had lost its
        # indentation; a checkpoint after every epoch is assumed -- confirm.
        self.save_model(epoch, best_dev=False)

        # A bare ``break`` above only leaves the batch loop; leave the epoch
        # loop too, otherwise each remaining epoch trains one extra batch.
        if reached_max_iter:
            break

    return None