in src/optim.py [0:0]
def get_optimizer(parameters, s):
    """
    Build an optimizer from a compact string specification.

    The specification is a method name, optionally followed by
    comma-separated ``name=value`` pairs whose values are numbers:
        - "sgd,lr=0.01"
        - "adagrad,lr=0.1,lr_decay=0.05"
        - "adam,lr=1e-4,beta1=0.9,beta2=0.98"

    Every value is converted with ``float``. For the Adam variants
    ("adam", "adam_inverse_sqrt", "adam_cosine") the optional ``beta1`` and
    ``beta2`` entries are merged into a single ``betas`` tuple (defaults
    0.9 and 0.999) and removed from the keyword arguments.

    Args:
        parameters: forwarded unchanged as the first positional argument of
            the selected optimizer class.
        s: the specification string described above.

    Returns:
        The object produced by calling the selected optimizer class with
        ``parameters`` and the parsed keyword arguments.

    Raises:
        AssertionError: if a pair is not of the form ``name=value``, if a
            value is not a plain or scientific-notation number, if "sgd" is
            requested without ``lr``, or if the optimizer constructor does
            not start with ``(self, params)``.
        Exception: if the method name is unknown, or if a parameter name is
            not accepted by the optimizer constructor.
    """
    # Everything before the first comma is the method name; the remainder
    # (if any) is the parameter list.
    method, _, raw_params = s.partition(",")
    optim_params = {}
    if "," in s:
        for x in raw_params.split(","):
            split = x.split("=")
            assert len(split) == 2, 'Expected "name=value", got "%s"' % x
            # Decimal notation with an optional exponent, so that values
            # such as "1e-4" are accepted in addition to "0.0001".
            assert (
                re.match(
                    r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$", split[1]
                )
                is not None
            ), 'Invalid numeric value "%s" for "%s"' % (split[1], split[0])
            optim_params[split[0]] = float(split[1])

    if method == "adadelta":
        optim_fn = optim.Adadelta
    elif method == "adagrad":
        optim_fn = optim.Adagrad
    elif method in ("adam", "adam_inverse_sqrt", "adam_cosine"):
        # These three classes are looked up in the enclosing module rather
        # than in `optim`. They are only referenced when actually requested.
        if method == "adam":
            optim_fn = Adam
        elif method == "adam_inverse_sqrt":
            optim_fn = AdamInverseSqrtWithWarmup
        else:
            optim_fn = AdamCosineWithWarmup
        # The constructors take one `betas` tuple, while the string format
        # exposes the two coefficients separately.
        optim_params["betas"] = (
            optim_params.pop("beta1", 0.9),
            optim_params.pop("beta2", 0.999),
        )
    elif method == "adamax":
        optim_fn = optim.Adamax
    elif method == "asgd":
        optim_fn = optim.ASGD
    elif method == "rmsprop":
        optim_fn = optim.RMSprop
    elif method == "rprop":
        optim_fn = optim.Rprop
    elif method == "sgd":
        optim_fn = optim.SGD
        # A learning rate must always be given explicitly for SGD.
        assert "lr" in optim_params, 'The "sgd" method requires "lr"'
    else:
        raise Exception('Unknown optimization method: "%s"' % method)

    # check that we give good parameters to the optimizer
    # `inspect.getargspec` was removed in Python 3.11 and rejected
    # signatures with annotations or keyword-only parameters before that;
    # `getfullargspec` handles both. Keyword-only names are valid
    # constructor arguments too, so they are part of the accepted set.
    spec = inspect.getfullargspec(optim_fn.__init__)
    expected_args = list(spec.args) + list(spec.kwonlyargs)
    assert expected_args[:2] == ["self", "params"]
    if not all(k in expected_args[2:] for k in optim_params.keys()):
        raise Exception(
            'Unexpected parameters: expected "%s", got "%s"'
            % (str(expected_args[2:]), str(optim_params.keys()))
        )

    return optim_fn(parameters, **optim_params)