in models.py [0:0]
def update(self, i, x, y, g, epoch):
    """Run one optimisation step on a mini-batch and return its loss.

    The three batch tensors are moved with ``.cuda()`` and passed, together
    with ``i`` and ``epoch``, to ``self.compute_loss_value_``. If that returns
    ``None`` the batch is skipped (no backward pass, no optimizer step).
    Otherwise the loss is back-propagated, gradients of
    ``self.network.parameters()`` are optionally clipped to a total norm of
    1.0, and the optimizer (and LR scheduler, when one is set) is stepped.

    Args:
        i: Batch identifier, forwarded unchanged to ``compute_loss_value_``.
        x: Input batch (anything exposing ``.cuda()``).
        y: Target batch (anything exposing ``.cuda()``).
        g: Presumably per-sample group labels — TODO confirm against caller.
        epoch: Current epoch; stored on ``self.last_epoch`` in every path.

    Returns:
        The loss converted with ``.item()``, or ``None`` when
        ``compute_loss_value_`` produced no loss.
    """
    x = x.cuda()
    y = y.cuda()
    g = g.cuda()
    loss = self.compute_loss_value_(i, x, y, g, epoch)

    if loss is None:
        # Nothing to optimise for this batch, but the epoch is still recorded.
        self.last_epoch = epoch
        return None

    self.optimizer.zero_grad()
    loss.backward()
    if self.clip_grad:
        # Rescale gradients so their combined norm does not exceed 1.0.
        torch.nn.utils.clip_grad_norm_(self.network.parameters(), 1.0)
    self.optimizer.step()

    # The scheduler (if any) advances once per optimizer step, not per epoch.
    if self.lr_scheduler is not None:
        self.lr_scheduler.step()

    if self.data_type == "text":
        # Text models additionally clear the network's gradients after the
        # step; the motivation is not visible from this method.
        self.network.zero_grad()

    # Convert before recording the epoch, matching the original ordering.
    scalar_loss = loss.item()
    self.last_epoch = epoch
    return scalar_loss