slowfast/models/optimizer.py
import torch


def construct_optimizer(model, cfg):
"""
Construct a stochastic gradient descent or ADAM optimizer with momentum.
Details can be found in:
Herbert Robbins, and Sutton Monro. "A stochastic approximation method."
and
Diederik P.Kingma, and Jimmy Ba.
"Adam: A Method for Stochastic Optimization."
Args:
model (model): model to perform stochastic gradient descent
optimization or ADAM optimization.
cfg (config): configs of hyper-parameters of SGD or ADAM, includes base
learning rate, momentum, weight_decay, dampening, and etc.
"""
    # Partition model parameters into four lists: normalization-layer params,
    # regular params, params excluded from weight decay, and frozen params.
    bn_parameters = []
    non_bn_parameters = []
    zero_parameters = []
    no_grad_parameters = []
    skip = {}
    if cfg.NUM_GPUS > 1:
        # Under (Distributed)DataParallel the model is wrapped in `.module`,
        # so names returned by no_weight_decay() need the "module." prefix to
        # match the names produced by model.named_modules() below.
        if hasattr(model.module, "no_weight_decay"):
            skip = model.module.no_weight_decay()
            skip = {"module." + v for v in skip}
    else:
        if hasattr(model, "no_weight_decay"):
            skip = model.no_weight_decay()
    for name, m in model.named_modules():
        is_bn = isinstance(m, torch.nn.modules.batchnorm._NormBase)
        for p in m.parameters(recurse=False):
            if not p.requires_grad:
                no_grad_parameters.append(p)
            elif is_bn:
                bn_parameters.append(p)
            elif name in skip:
                zero_parameters.append(p)
            elif cfg.SOLVER.ZERO_WD_1D_PARAM and (
                len(p.shape) == 1 or name.endswith(".bias")
            ):
                # 1-D tensors (biases, norm scales, etc.) are excluded from
                # weight decay when ZERO_WD_1D_PARAM is set.
                zero_parameters.append(p)
            else:
                non_bn_parameters.append(p)
    optim_params = [
        {"params": bn_parameters, "weight_decay": cfg.BN.WEIGHT_DECAY},
        {"params": non_bn_parameters, "weight_decay": cfg.SOLVER.WEIGHT_DECAY},
        {"params": zero_parameters, "weight_decay": 0.0},
    ]
    # Drop empty groups; the per-group "weight_decay" set here overrides the
    # optimizer-level default passed to the constructor below.
    optim_params = [x for x in optim_params if len(x["params"])]
# Check all parameters will be passed into optimizer.
assert len(list(model.parameters())) == len(non_bn_parameters) + len(
bn_parameters
) + len(zero_parameters) + len(
no_grad_parameters
), "parameter size does not match: {} + {} + {} + {} != {}".format(
len(non_bn_parameters),
len(bn_parameters),
len(zero_parameters),
len(no_grad_parameters),
len(list(model.parameters())),
)
print(
"bn {}, non bn {}, zero {} no grad {}".format(
len(bn_parameters),
len(non_bn_parameters),
len(zero_parameters),
len(no_grad_parameters),
)
)
if cfg.SOLVER.OPTIMIZING_METHOD == "sgd":
return torch.optim.SGD(
optim_params,
lr=cfg.SOLVER.BASE_LR,
momentum=cfg.SOLVER.MOMENTUM,
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
dampening=cfg.SOLVER.DAMPENING,
nesterov=cfg.SOLVER.NESTEROV,
)
elif cfg.SOLVER.OPTIMIZING_METHOD == "adam":
return torch.optim.Adam(
optim_params,
lr=cfg.SOLVER.BASE_LR,
betas=(0.9, 0.999),
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
)
elif cfg.SOLVER.OPTIMIZING_METHOD == "adamw":
return torch.optim.AdamW(
optim_params,
lr=cfg.SOLVER.BASE_LR,
eps=1e-08,
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
)
    else:
        raise NotImplementedError(
            "Unsupported optimizer: {}".format(cfg.SOLVER.OPTIMIZING_METHOD)
        )
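

# Usage sketch (illustrative, not part of the original file): exercise
# construct_optimizer on a toy model. The real SlowFast `cfg` is a CfgNode;
# a SimpleNamespace stand-in carrying only the fields this function reads is
# assumed here.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        NUM_GPUS=1,
        BN=SimpleNamespace(WEIGHT_DECAY=0.0),
        SOLVER=SimpleNamespace(
            OPTIMIZING_METHOD="sgd",
            BASE_LR=0.1,
            MOMENTUM=0.9,
            DAMPENING=0.0,
            NESTEROV=True,
            WEIGHT_DECAY=1e-4,
            ZERO_WD_1D_PARAM=True,
        ),
    )
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.BatchNorm1d(32),
        torch.nn.Linear(32, 10),
    )
    optimizer = construct_optimizer(model, cfg)
    # Expect three groups: BatchNorm params (BN.WEIGHT_DECAY), 2-D weights
    # (SOLVER.WEIGHT_DECAY), and 1-D biases (0.0, via ZERO_WD_1D_PARAM).
    for group in optimizer.param_groups:
        print(len(group["params"]), group["weight_decay"])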