in src/sgd.py [0:0]
def __init__(self, params, lr, momentum=0, weight_decay=0, nesterov=False):
    """Validate SGD hyperparameters and register them as group defaults.

    Args:
        params: Parameters (or parameter groups) to optimize; forwarded
            unchanged to the base-class constructor.
        lr: Learning rate. Must be a non-negative number (not NaN).
        momentum: Momentum factor. Must be non-negative (not NaN).
            Defaults to 0.
        weight_decay: Weight-decay coefficient. Must be non-negative
            (not NaN). Defaults to 0.
        nesterov: Whether to use Nesterov momentum. Requires
            ``momentum > 0``. Defaults to False.

    Raises:
        ValueError: If ``lr``, ``momentum`` or ``weight_decay`` is negative
            or NaN, or if ``nesterov`` is requested with zero momentum.
    """
    # `not 0.0 <= x` instead of `x < 0.0`: both reject negatives, but only
    # this form also rejects NaN (every comparison with NaN is False).
    if not 0.0 <= lr:
        raise ValueError(f'Invalid learning rate: {lr}')
    if not 0.0 <= momentum:
        raise ValueError(f'Invalid momentum value: {momentum}')
    if not 0.0 <= weight_decay:
        raise ValueError(f'Invalid weight_decay value: {weight_decay}')
    # momentum is known to be >= 0 here, so exactly zero is the only
    # value incompatible with Nesterov momentum.
    if nesterov and momentum == 0.0:
        raise ValueError('Nesterov needs momentum > 0')
    defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay,
                    nesterov=nesterov)
    # NOTE(review): base class is presumably torch.optim.Optimizer, which
    # builds param_groups from `params` and `defaults` — confirm.
    super().__init__(params, defaults)