in python/singa/opt.py [0:0]
def __init__(self,
             lr=0.1,
             beta_1=0.9,
             beta_2=0.999,
             epsilon=1e-8,
             weight_decay=0):
    """Create an Adam optimizer.

    Every hyperparameter other than ``lr`` may be given either as a plain
    number or as a ``DecayScheduler`` instance. A plain number is wrapped
    in ``Constant``. Each resulting scheduler is stored on ``self``. It is
    also evaluated once at the current ``self.step_counter``, and that
    result is cached in the matching ``*_value`` attribute
    (``decay_value``, ``beta_1_value``, ``beta_2_value``,
    ``epsilon_value``).

    Args:
        lr: learning rate, forwarded unchanged to the base optimizer's
            ``__init__``.
        beta_1: Adam ``beta_1`` (conventionally the first-moment decay
            rate); number or ``DecayScheduler``.
        beta_2: Adam ``beta_2`` (conventionally the second-moment decay
            rate); number or ``DecayScheduler``.
        epsilon: Adam ``epsilon`` (conventionally a small constant for
            numerical stability); number or ``DecayScheduler``.
        weight_decay: weight-decay coefficient; number or
            ``DecayScheduler``. A numeric value must be non-negative.

    Raises:
        ValueError: if a numeric ``weight_decay`` is negative or NaN.
        TypeError: if any hyperparameter is neither an int/float nor a
            ``DecayScheduler`` (``bool`` is rejected too).
    """

    def _as_scheduler(value, name, check_non_negative=False):
        # Normalise one hyperparameter to a DecayScheduler.
        # Numbers are wrapped in Constant, schedulers are returned as-is,
        # and anything else raises TypeError. bool is a subclass of int,
        # so it is excluded explicitly; the previous exact-type check
        # rejected True/False as well.
        if isinstance(value, (int, float)) and not isinstance(value, bool):
            # Use `not value >= 0.0` rather than `value < 0.0`. NaN compares
            # False against everything, so only this form also rejects NaN.
            if check_non_negative and not value >= 0.0:
                raise ValueError("Invalid {} value: {}".format(name, value))
            return Constant(value)
        if isinstance(value, DecayScheduler):
            return value
        raise TypeError("Wrong {} type".format(name))

    super(Adam, self).__init__(lr)

    # init weight_decay
    self.weight_decay = _as_scheduler(weight_decay,
                                      "weight_decay",
                                      check_non_negative=True)
    self.decay_value = self.weight_decay(self.step_counter)

    # init beta_1
    self.beta_1 = _as_scheduler(beta_1, "beta_1")
    self.beta_1_value = self.beta_1(self.step_counter)

    # init beta_2
    self.beta_2 = _as_scheduler(beta_2, "beta_2")
    self.beta_2_value = self.beta_2(self.step_counter)

    # init epsilon
    self.epsilon = _as_scheduler(epsilon, "epsilon")
    self.epsilon_value = self.epsilon(self.step_counter)

    # init m and v: empty per-parameter state dicts. They are presumably
    # the first/second moment estimates filled in by the update step,
    # which is not visible here; confirm against the caller.
    self.m = dict()
    self.v = dict()