# in utils/optimizer.py [0:0]
def _parse_opt_config(opt_config, epochs):
    """
    Split an optimizer spec string into (method, optim_params, lr_schedule).

    `optim_params` maps parameter names to floats. `lr_schedule` is None when
    the spec does not set "lr"; otherwise it is whatever
    `repeat_to(schedule, epochs)` returns for the parsed schedule.
    """
    method, has_params, param_spec = opt_config.partition(',')
    optim_params = {}
    lr_schedule = None
    if not has_params:
        # Bare method name such as "adam": nothing to parse.
        return method, optim_params, lr_schedule

    # Compiled once instead of being re-evaluated for every parameter.
    plain_number = re.compile(r"^[+-]?(\d+(\.\d*)?|\.\d+)$")
    cosine_number = re.compile(r"^cos:[+-]?(\d+(\.\d*)?|\.\d+)$")

    for item in param_spec.split(','):
        # e.g. item = "lr=0.1-0.01" or item = "weight_decay=0.001"
        pieces = item.split('=')
        assert len(pieces) == 2, 'Expected "name=value", got "%s"' % item
        param_name, param_value = pieces
        is_lr = param_name == "lr"
        # Any parameter may be a plain decimal number; only "lr" may also be
        # a dash-separated list of values or a "cos:<number>" spec.
        assert (
            plain_number.match(param_value) is not None
            or (is_lr and "-" in param_value)
            or (is_lr and cosine_number.match(param_value) is not None)
        ), 'Invalid value "%s" for optimizer parameter "%s"' % (param_value, param_name)

        if not is_lr:
            optim_params[param_name] = float(param_value)
            continue

        if param_value.startswith("cos:"):
            # Half-cosine decay starting at lr_init for epoch 0.
            lr_init = float(param_value[4:])
            lr_schedule = [
                lr_init * (1 + np.cos(np.pi * epoch / epochs)) / 2
                for epoch in range(epochs)
            ]
        else:
            lr_schedule = [float(lr) for lr in param_value.split("-")]
        # The optimizer is created with the first value of the schedule.
        optim_params[param_name] = float(lr_schedule[0])
        lr_schedule = repeat_to(lr_schedule, epochs)

    return method, optim_params, lr_schedule


def get_optimizer(parameters, opt_config, epochs):
    """
    Parse optimizer parameters and build the corresponding optimizer.

    opt_config should be of the form:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"

    The "lr" parameter additionally accepts:
        - "lr=0.1-0.01": dash-separated values, expanded to a schedule with
          `repeat_to(values, epochs)`
        - "lr=cos:0.1": lr_init * (1 + cos(pi * epoch / epochs)) / 2 for each
          epoch in range(epochs)

    Because "-" separates schedule values, negative or exponent-notation
    learning rates (e.g. "1e-4") cannot be expressed; other parameters must
    be plain decimal numbers.

    For "adam", the optional "beta1" / "beta2" entries (defaults 0.9 and
    0.999) are merged into the "betas" tuple. "sgd" requires "lr".

    Args:
        parameters: passed unchanged as the first argument of the optimizer
            constructor.
        opt_config: optimizer spec string as described above.
        epochs: number of epochs the learning-rate schedule is built for.

    Returns:
        A tuple (optimizer, lr_schedule). lr_schedule is None when the spec
        does not set "lr".

    Raises:
        AssertionError: malformed "name=value" item, invalid value, or
            missing "lr" for sgd.
        Exception: unknown optimization method, or a parameter name the
            optimizer constructor does not accept.
    """
    method, optim_params, lr_schedule = _parse_opt_config(opt_config, epochs)

    optimizers = {
        'adadelta': optim.Adadelta,
        'adagrad': optim.Adagrad,
        'adam': optim.Adam,
        'adamax': optim.Adamax,
        'asgd': optim.ASGD,
        'rmsprop': optim.RMSprop,
        'rprop': optim.Rprop,
        'sgd': optim.SGD,
    }
    if method not in optimizers:
        raise Exception('Unknown optimization method: "%s"' % method)
    optim_fn = optimizers[method]

    if method == 'adam':
        # The constructor takes a single "betas" tuple, not beta1 / beta2.
        optim_params['betas'] = (
            optim_params.pop('beta1', 0.9),
            optim_params.pop('beta2', 0.999),
        )
    elif method == 'sgd':
        assert 'lr' in optim_params, 'sgd requires an explicit "lr" parameter'

    # check that we give good parameters to the optimizer
    # getfullargspec replaces inspect.getargspec, which was removed in
    # Python 3.11 and rejected signatures with keyword-only arguments.
    spec = inspect.getfullargspec(optim_fn.__init__)
    assert spec.args[:2] == ['self', 'params']
    expected_args = spec.args[2:] + spec.kwonlyargs
    if not all(k in expected_args for k in optim_params):
        raise Exception('Unexpected parameters: expected "%s", got "%s"' % (
            str(expected_args), str(optim_params.keys())))

    logger.info("Schedule of %s: %s" % (opt_config, str(lr_schedule)))
    return optim_fn(parameters, **optim_params), lr_schedule