in models.py [0:0]
def compute_loss_value_(self, i, x, y, g, epoch):
    """Return the training loss for one batch, or score the batch on epoch T.

    The schedule is driven by ``self.hparams["T"]``:

    * On every epoch other than ``T``, run the network and return the mean
      of ``self.loss(predictions, y)``.
    * On epoch ``T``, run the network in eval mode and find the examples it
      misclassifies. Add ``self.hparams["up"] - 1`` to their entries in
      ``self.weights``, then return ``None``.
    * On epoch ``T + 1``, when ``self.last_epoch == T``, re-initialise the
      model via ``self.init_model_(self.data_type, text_optim="adamw")``
      before the forward pass.

    NOTE(review): this looks like Just Train Twice (JTT) error-set
    upweighting; the caller presumably skips the optimizer step when
    ``None`` is returned. Confirm against the caller.

    Args:
        i: Indices into ``self.weights`` for the examples in this batch
            (presumably dataset example ids; verify against the loader).
        x: Network input batch.
        y: Targets. These are class indices when the network emits several
            logits per example, or 0/1 targets when it emits a single logit.
        g: Group labels. Unused here; kept for interface compatibility.
        epoch: Current epoch number.

    Returns:
        The mean loss, or ``None`` on the scoring epoch ``T``.
    """
    scoring_epoch = self.hparams["T"]
    if epoch == scoring_epoch + 1 and self.last_epoch == scoring_epoch:
        # First batch after the scoring epoch: start again from a fresh model.
        self.init_model_(self.data_type, text_optim="adamw")

    if epoch != scoring_epoch:
        return self.loss(self.network(x), y).mean()

    # Scoring epoch. Switch to eval mode *before* the forward pass. For a
    # torch nn.Module this disables dropout and batch-norm stat updates.
    # Restore train mode even if the forward pass raises.
    self.eval()
    try:
        predictions = self.network(x)
    finally:
        self.train()

    # Decide binary vs multi-class from the class dimension. squeeze() would
    # treat a single-row multi-class batch (1, C) as binary. Flattening makes
    # (B, 1) logits/targets compare element-wise with (B,) instead of
    # broadcasting to (B, B).
    if predictions.ndim < 2 or predictions.shape[1] == 1:
        predicted = predictions.reshape(-1) > 0
    else:
        predicted = predictions.argmax(1)

    # Only misclassified examples (ne) get the extra weight. The mask lives
    # on CPU; self.weights is presumably a CPU tensor (TODO confirm).
    mistakes = predicted.cpu().ne(y.reshape(-1).cpu()).float()
    self.weights[i] += mistakes.detach() * (self.hparams["up"] - 1)
    return None