in python/singa/opt.py [0:0]
def __init__(self,
             lr=0.1,
             momentum=0,
             dampening=0,
             weight_decay=0,
             nesterov=False,
             dtype=tensor.float32):
    """Create an SGD optimizer with optional momentum, dampening,
    weight decay and Nesterov momentum.

    Each of ``momentum``, ``dampening`` and ``weight_decay`` may be given
    either as a plain int/float (wrapped in a ``Constant`` scheduler) or as a
    ``DecayScheduler`` instance (used as-is).

    Args:
        lr: learning rate; forwarded to the base optimizer together with
            ``dtype``.
        momentum: momentum factor; must not be negative.
        dampening: dampening applied to the momentum.
        weight_decay: weight decay coefficient; must not be negative.
        nesterov: whether to use Nesterov momentum. Requires a positive
            initial momentum and a zero initial dampening.
        dtype: the per-step hyper-parameter values computed here are cast
            to ``self.dtype`` (set by the base optimizer from this argument
            -- presumably; verify against the base class).

    Raises:
        TypeError: if momentum, dampening or weight_decay is neither an
            int/float nor a ``DecayScheduler``.
        ValueError: if the (initial) momentum or weight_decay is negative,
            or if ``nesterov`` is set without a positive momentum and zero
            dampening.
    """
    super(SGD, self).__init__(lr, dtype)

    def _as_scheduler(value, name, non_negative):
        """Return ``(scheduler, initial_value)`` for a hyper-parameter."""
        # bool is a subclass of int; exclude it so True/False keep being
        # rejected, as the former ``type(x) == int`` check did.
        is_number = (isinstance(value, (int, float)) and
                     not isinstance(value, bool))
        if not is_number and not isinstance(value, DecayScheduler):
            raise TypeError("Wrong {} type".format(name))
        init = value if is_number else value.init_value
        # Validate the starting value for schedulers too, not only scalars.
        if non_negative and init < 0.0:
            raise ValueError("Invalid {} value: {}".format(name, init))
        return (Constant(value) if is_number else value), init

    # init hyper-parameter schedulers
    self.momentum, momentum_init = _as_scheduler(momentum, "momentum", True)
    self.dampening, dampening_init = _as_scheduler(dampening, "dampening",
                                                   False)
    self.weight_decay, _ = _as_scheduler(weight_decay, "weight_decay", True)

    # check value before allocating the per-step value tensors
    if nesterov and (momentum_init <= 0 or dampening_init != 0):
        raise ValueError(
            "Nesterov momentum requires a momentum and zero dampening")

    # evaluate each scheduler at the current step, cast to the optimizer dtype
    self.mom_value = self.momentum(self.step_counter).as_type(self.dtype)
    self.dam_value = self.dampening(self.step_counter).as_type(self.dtype)
    self.decay_value = self.weight_decay(self.step_counter).as_type(
        self.dtype)

    # init other params
    self.nesterov = nesterov
    self.moments = dict()