in src/utils.py [0:0]
def update_lr(self, loss):
    """
    Update the learning rate based on the current loss value and historic loss values.

    Newbob-style schedule: if the loss degraded compared to the previous epoch
    (and further decay is still permitted), restore the previously saved network
    snapshot and multiply every parameter group's learning rate by ``self.decay``.
    Otherwise record ``loss`` as the new reference. In either case, save a fresh
    snapshot afterwards (while decay is still permitted) so it can be restored
    on a future degradation.

    :param loss: the loss after the current iteration
    """
    if loss > self.last_epoch_loss and self.decay < 1.0 and self.total_decay > self.max_decay:
        # Accumulate the applied decay; total_decay shrinks toward the
        # max_decay floor, which eventually disables further decay.
        self.total_decay = self.total_decay * self.decay
        print(f"NewbobAdam: Decay learning rate (loss degraded from {self.last_epoch_loss} to {loss})."
              f" Total decay: {self.total_decay}")
        # restore previous network state
        self.net.load(self.artifacts_dir, suffix="newbob")
        # decrease learning rate
        for param_group in self.param_groups:
            param_group['lr'] = param_group['lr'] * self.decay
    else:
        self.last_epoch_loss = loss
    # Save last snapshot to restore it in case of a future lr decrease.
    # NOTE: this deliberately re-checks total_decay AFTER the multiplication
    # above, so no snapshot is written once the decay budget is exhausted.
    if self.decay < 1.0 and self.total_decay > self.max_decay:
        self.net.save(self.artifacts_dir, suffix="newbob")