def train_epoch()

in backup/baselines/ruber/train.py [0:0]


    def train_epoch(self, epoch=0):
        """
        Train the unreferenced metric for one pass over the training split.

        For each minibatch the model scores the responses in ``Y`` (``diff_true``)
        and the responses in ``Y_hat`` (``diff_false``) against the same inputs
        ``X``. It is then optimised with the hinge / margin-ranking loss
        ``max(0, margin - diff_true + diff_false)``, averaged over the batch.

        Running means of the loss and of both scores are written to
        ``self.logbook`` whenever ``i % log_interval == 0`` (``i`` is the sample
        offset of the batch) and after the final minibatch. ``self.train_step``
        is incremented once per such log event.

        :param epoch: epoch number, recorded in the logged metrics.
        :return: None
        """
        self.model.train()
        X, Y, Y_hat = self.data.prepare_data(self.data.train_indices, False)
        batch_size = self.args.batch_size
        # Running values accumulated between two log events.
        loss_a = []
        t_loss_a = []
        f_loss_a = []
        for i in range(0, len(X), batch_size):
            # Score the (input, Y) pairs.
            inp_vec, inp_len = batchify(X[i : i + batch_size], False)
            outp_vec, outp_len = batchify(Y[i : i + batch_size], False)
            inp_vec = inp_vec.to(self.device)
            outp_vec = outp_vec.to(self.device)
            diff_true = self.model(inp_vec, inp_len, outp_vec, outp_len)
            # Score the same inputs against the Y_hat responses.
            y_false, y_len = batchify(Y_hat[i : i + batch_size], False)
            y_false = y_false.to(self.device)
            diff_false = self.model(inp_vec, inp_len, y_false, y_len)
            # Hinge loss: penalise unless diff_true exceeds diff_false by `margin`.
            # The scalar margin broadcasts over the score tensor.
            loss = torch.clamp(
                self.args.margin - diff_true + diff_false,
                min=0.0,
            ).mean()
            self.optimizer.zero_grad()
            loss.backward()
            loss_a.append(loss.item())
            t_loss_a.append(diff_true.mean().item())
            f_loss_a.append(diff_false.mean().item())
            self.optimizer.step()

            # The batch starting at `i` is the last one iff it reaches the end of X.
            is_last_batch = i + batch_size >= len(X)
            # NOTE(review): `i` is a sample offset (a multiple of batch_size), not a
            # batch index, so this only fires when `i` is a multiple of both
            # batch_size and log_interval. Confirm this is the intended unit.
            if i % self.args.log_interval == 0 or is_last_batch:
                metrics = {
                    "mode": "train",
                    "minibatch": self.train_step,
                    "loss": np.mean(loss_a),
                    "true_loss": np.mean(t_loss_a),
                    "false_loss": np.mean(f_loss_a),
                    "epoch": epoch,
                }
                self.train_step += 1
                loss_a = []
                t_loss_a = []
                f_loss_a = []
                self.logbook.write_metric_logs(metrics)