def compute_loss_value_()

in models.py [0:0]


    def compute_loss_value_(self, i, x, y, g, epoch):
        """Compute the training loss for one batch, or update per-example weights.

        Behaviour depends on ``epoch`` relative to ``self.hparams["T"]``:

        * ``epoch == T + 1`` and ``self.last_epoch == T``: ``init_model_`` is
          called first (``text_optim="adamw"``), then the batch is processed
          as a normal training step.
        * ``epoch != T``: returns ``self.loss(predictions, y).mean()``.
        * ``epoch == T``: no loss is computed. The network is run in eval mode
          and every example of the batch that it misclassifies has its entry
          in ``self.weights`` increased by ``self.hparams["up"] - 1``.
          Returns ``None``. NOTE(review): this appears to be the error-set
          identification step of JTT ("Just Train Twice"); confirm.

        Args:
            i: Indices of the batch examples into ``self.weights``.
            x: Network input for the batch.
            y: Targets for the batch (class indices, or 0/1 labels when the
                network has a single-logit binary output).
            g: Group labels; unused here, kept for interface compatibility.
            epoch: Current epoch number.

        Returns:
            The mean loss of the batch, or ``None`` during epoch ``T``.
        """
        identify_epoch = self.hparams["T"]

        # ``last_epoch`` is not updated in this method; presumably the caller
        # maintains it, so the re-initialisation fires only on the T -> T+1
        # transition.
        if epoch == identify_epoch + 1 and self.last_epoch == identify_epoch:
            self.init_model_(self.data_type, text_optim="adamw")

        if epoch != identify_epoch:
            predictions = self.network(x)
            return self.loss(predictions, y).mean()

        # The mode switch must happen *before* the forward pass to have any
        # effect (e.g. disable dropout). try/finally restores training mode
        # even if the forward pass raises.
        self.eval()
        try:
            predictions = self.network(x).detach()
        finally:
            self.train()

        if predictions.ndim == 1 or predictions.size(-1) == 1:
            # Single-logit binary output, shape (B,) or (B, 1): positive
            # logit => class 1. Flatten so it lines up with ``y``.
            predicted = predictions.reshape(-1) > 0
        else:
            # Multi-class output (B, C) with C > 1; also correct for B == 1.
            predicted = predictions.argmax(1)

        # Compare on CPU (presumably ``self.weights`` lives there); moving
        # ``y`` as well avoids a device mismatch when it is on an accelerator.
        mistakes = predicted.cpu().ne(y.cpu()).float()

        # Upweight misclassified examples: weight grows by ``up - 1``.
        self.weights[i] += mistakes * (self.hparams["up"] - 1)
        return None