def train()

in PairSupCon/training.py [0:0]


    def train(self):
        """Run the training loop over ``self.train_loader`` for ``self.args.epochs`` epochs.

        For each batch, inputs are built with ``self.prepare_pairwise_input`` and
        passed to ``self.train_step``, whose returned ``losses`` mapping is logged.
        ``self.gstep`` is the global step counter and is incremented once per
        trained batch.

        Logging (via ``statistics_log``) happens when ``gstep`` is a multiple of
        ``self.args.logging_step``, equals the total batch count
        (``epochs * len(train_loader)``), or equals ``self.args.max_iter``.
        When ``self.args.dev_set == "sts"``, each logging point also runs
        ``self.eval_stsdev()``, merges its metrics into the logged losses, and
        calls ``self.save_model(epoch, best_dev=True)`` whenever
        ``"eval_stsb_spearman"`` exceeds ``self.dev_objective`` (which is then
        updated). ``self.save_model(epoch, best_dev=False)`` is called at the end
        of every epoch that was entered.

        Training stops once ``self.gstep`` exceeds ``self.args.max_iter``. The
        step with ``gstep == max_iter`` is the last one trained. No further
        epochs are started after that, so no extra ``save_model`` calls happen.
        ``self.args.max_iter`` is assumed to be numeric (it is compared with ``>``).

        Returns:
            None
        """
        num_batches = len(self.train_loader)
        all_iter = self.args.epochs * num_batches
        print('\n={}/{}=Iterations/Batches'.format(all_iter, num_batches))

        self.model.train()
        for epoch in range(self.args.epochs):
            # Stop before entering another epoch once the iteration budget is
            # spent. A `break` in the inner loop alone would otherwise let every
            # remaining epoch run one more step and write one more checkpoint.
            if self.gstep > self.args.max_iter:
                break

            for batch in self.train_loader:
                # Check the budget *before* training, and independently of the
                # logging condition. As an `elif` of that condition, the check
                # was skipped whenever gstep was a multiple of logging_step.
                if self.gstep > self.args.max_iter:
                    break

                input_ids, attention_mask, pairsimi = self.prepare_pairwise_input(batch)

                losses = self.train_step(input_ids, attention_mask, pairsimi)

                should_log = (
                    (self.gstep % self.args.logging_step == 0)
                    or (self.gstep == all_iter)
                    or (self.gstep == self.args.max_iter)
                )
                if should_log:
                    if self.args.dev_set == "sts":  # testing the effectiveness of stsb as dev set, default False
                        # Switch to eval mode for the dev pass, then restore train mode.
                        self.model.eval()
                        sts_metrics = self.eval_stsdev()
                        losses.update(sts_metrics)

                        # Keep a separate "best" checkpoint keyed on STS-B Spearman.
                        if sts_metrics["eval_stsb_spearman"] > self.dev_objective:
                            self.save_model(epoch, best_dev=True)
                            self.dev_objective = sts_metrics["eval_stsb_spearman"]
                        self.model.train()

                    statistics_log(self.args.tensorboard, losses=losses, global_step=self.gstep)

                self.gstep += 1
            self.save_model(epoch, best_dev=False)
        return None