# train_epoch — extracted from backup/train.py

    def train_epoch(self, epoch=0):
        """
        Train the binary predictor for one epoch.

        For every batch the model scores a (input, candidate) pair twice:
        once against the true candidate (target 1) and once against the
        false candidate (target 0). The two losses are summed and
        backpropagated together.

        :param epoch: current epoch index, forwarded to the dataloader and
            recorded in the metric logs.
        :return: None. Side effects: updates model parameters, advances
            ``self.train_step``, and writes metrics via ``self.logbook``.
        """
        self.model.train()
        # Running per-interval loss accumulators, reset on every flush.
        loss_a = []
        t_loss_a = []
        f_loss_a = []

        def _flush_metrics():
            # Log the mean of the accumulated losses, then reset the
            # buffers in place (clear(), not rebind, so the closure and
            # the enclosing scope keep sharing the same lists).
            self.logbook.write_metric_logs(
                {
                    "mode": "train",
                    "minibatch": self.train_step,
                    "loss": np.mean(loss_a),
                    "true_loss": np.mean(t_loss_a),
                    "false_loss": np.mean(f_loss_a),
                    "epoch": epoch,
                }
            )
            self.train_step += 1
            loss_a.clear()
            t_loss_a.clear()
            f_loss_a.clear()

        dl = self.data.get_dataloader(mode="train", epoch=epoch)
        for i, batch in enumerate(dl):
            inp, inp_len, y_true, y_false = batch
            pred_true = self.model(inp, inp_len, y_true)
            pred_false = self.model(inp, inp_len, y_false)
            # Binary targets: 1 for the true candidate, 0 for the false one.
            # Construct directly on the target device to avoid a CPU alloc + copy.
            t_loss = self.loss_fn(
                pred_true, torch.ones(pred_true.size(0), 1, device=self.device)
            )
            f_loss = self.loss_fn(
                pred_false, torch.zeros(pred_false.size(0), 1, device=self.device)
            )
            loss = t_loss + f_loss
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            loss_a.append(loss.item())
            t_loss_a.append(t_loss.item())
            f_loss_a.append(f_loss.item())
            if i % self.args.log_interval == 0:
                _flush_metrics()
        # Post-epoch flush. Guarded: if the last interval flush landed on
        # the final batch the accumulators are empty, and np.mean([]) would
        # log nan (with a RuntimeWarning).
        if loss_a:
            _flush_metrics()