in banding_removal/fastmri/learning_rate_mixin.py
def start_of_epoch_hook(self, epoch):
    args = self.args
    momentum = args.momentum
    decay = args.decay
    # Select the learning rate for this epoch. The milestone schemes
    # multiply the base rate by 0.1 for every completed interval, so the
    # decay compounds once the epoch passes each milestone.
    if args.lr_reduction == "default":
        lr = args.lr * (0.1 ** (epoch // 75))
    elif args.lr_reduction == "none":
        lr = args.lr
    elif args.lr_reduction == "60-90":
        lr = args.lr * (0.1 ** (epoch // 60)) * (0.1 ** (epoch // 90))
    elif args.lr_reduction == "58":
        lr = args.lr * (0.1 ** (epoch // 58))
    elif args.lr_reduction == "80-120":
        lr = args.lr * (0.1 ** (epoch // 80)) * (0.1 ** (epoch // 120))
    elif args.lr_reduction == "150-225":
        lr = args.lr * (0.1 ** (epoch // 150)) * (0.1 ** (epoch // 225))
    elif args.lr_reduction == "40-50-55":
        lr = args.lr * (0.1 ** (epoch // 40)) * (0.1 ** (epoch // 50)) * (0.1 ** (epoch // 55))
    elif args.lr_reduction == "50-55-58":
        lr = args.lr * (0.1 ** (epoch // 50)) * (0.1 ** (epoch // 55)) * (0.1 ** (epoch // 58))
    elif args.lr_reduction == "linear":
        # Despite the name, this is exponential decay: 1.5% per epoch.
        lr = args.lr * (0.985 ** epoch)
    elif args.lr_reduction == "every20":
        lr = args.lr * (0.1 ** (epoch // 20))
    elif args.lr_reduction == "every40":
        lr = args.lr * (0.1 ** (epoch // 40))
    elif args.lr_reduction == "rampupto":
        # Warm up by running the first epoch at a tenth of the base rate.
        if epoch == 0:
            lr = args.lr * 0.1
        else:
            lr = args.lr
    else:
        raise ValueError(f"LR scheme not recognised: {args.lr_reduction}")
    # Apply the schedule to every parameter group, honouring an optional
    # per-group scaling factor (defaulting to 1.0 when a group has none).
    for i, param_group in enumerate(self.optimizer.param_groups):
        if 'group_scaling' not in param_group:
            param_group['group_scaling'] = 1.0
        group_lr = lr * param_group['group_scaling']
        param_group['lr'] = group_lr
        # momentum and decay are currently only logged, not applied:
        # param_group['momentum'] = momentum
        # param_group['weight_decay'] = decay
        logging.info(f"Group {i} Learning rate: {group_lr} momentum: {momentum} decay: {decay}")
    super().start_of_epoch_hook(epoch)
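For reference, a minimal standalone sketch of the schedule arithmetic. The `milestone_lr` helper below is hypothetical (not part of the repository), and the list of plain dicts merely stands in for `optimizer.param_groups`; it reproduces the step-decay factor used by the milestone schemes above and the per-group scaling applied in the loop.

```python
def milestone_lr(base_lr, epoch, milestones, factor=0.1):
    """Multiply base_lr by `factor` once per completed milestone interval,
    matching lr = base_lr * prod(factor ** (epoch // m) for m in milestones)."""
    lr = base_lr
    for m in milestones:
        lr *= factor ** (epoch // m)
    return lr

# Sanity check against the "60-90" branch: 0.1x at epoch 60, 0.01x at epoch 90.
assert abs(milestone_lr(1.0, 60, [60, 90]) - 0.1) < 1e-12
assert abs(milestone_lr(1.0, 90, [60, 90]) - 0.01) < 1e-12

# Per-group scaling, mirroring the loop over optimizer.param_groups.
param_groups = [{"group_scaling": 1.0}, {"group_scaling": 0.1}]
lr = milestone_lr(0.1, 60, [60, 90])
for group in param_groups:
    group["lr"] = lr * group.get("group_scaling", 1.0)
print([g["lr"] for g in param_groups])  # approx. [0.01, 0.001]
```

Note that the compounding form means a scheme like "60-90" decays by an extra factor of 0.1 at every later multiple of each milestone (e.g. epoch 120 picks up another step from the 60 milestone), which differs from a one-shot milestone scheduler.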