Dassl.pytorch/dassl/optim/optimizer.py (112 lines of code) (raw):

"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import warnings
import torch
import torch.nn as nn

from .radam import RAdam

AVAI_OPTIMS = ["adam", "amsgrad", "sgd", "rmsprop", "radam", "adamw"]


def _build_staged_param_groups(model, new_layers, lr, base_lr_mult):
    """Split the parameters of ``model`` into base and new-layer groups.

    Args:
        model (nn.Module): model whose direct children are partitioned.
            An ``nn.DataParallel`` wrapper is unwrapped first.
        new_layers (str, iterable of str, or None): names of the direct
            children whose parameters go into the "new" group.
        lr (float): reference learning rate.
        base_lr_mult (float): multiplier applied to ``lr`` for the base
            group.

    Returns:
        list of dict: two param groups. The first holds the base
        parameters with ``lr * base_lr_mult``; the second holds the
        new-layer parameters and sets no ``lr`` of its own, so it picks
        up the default passed to the optimizer constructor.

    Raises:
        TypeError: if ``model`` is not an ``nn.Module``.
    """
    if not isinstance(model, nn.Module):
        raise TypeError(
            "When staged_lr is True, model given to "
            "build_optimizer() must be an instance of nn.Module"
        )

    if isinstance(model, nn.DataParallel):
        model = model.module

    # The emptiness check has to come before the str check: previously it
    # was nested under ``isinstance(new_layers, str)`` and tested
    # ``is None``, so it could never fire, and a None value crashed later
    # on ``name in new_layers``.
    if not new_layers:
        warnings.warn("new_layers is empty (staged_lr is useless)")
        new_layers = []
    elif isinstance(new_layers, str):
        # A bare string would otherwise be matched by substring.
        new_layers = [new_layers]

    base_params = []
    new_params = []

    # Only direct children are inspected; nested modules follow their
    # top-level parent.
    for name, module in model.named_children():
        if name in new_layers:
            new_params.extend(module.parameters())
        else:
            base_params.extend(module.parameters())

    return [
        {
            "params": base_params,
            "lr": lr * base_lr_mult
        },
        {
            "params": new_params
        },
    ]


def build_optimizer(model, optim_cfg, param_groups=None):
    """A function wrapper for building an optimizer.

    Args:
        model (nn.Module or iterable): model, or an iterable that can be
            handed to the optimizer as-is (parameters or param groups).
        optim_cfg (CfgNode): optimization config. The attributes read are
            NAME, LR, WEIGHT_DECAY, MOMENTUM, SGD_DAMPNING, SGD_NESTEROV,
            RMSPROP_ALPHA, ADAM_BETA1, ADAM_BETA2, STAGED_LR, NEW_LAYERS
            and BASE_LR_MULT.
        param_groups: If provided, directly optimize param_groups and
            abandon model. STAGED_LR is ignored in that case (a warning
            is emitted).

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: if ``optim_cfg.NAME`` is not in ``AVAI_OPTIMS``.
        TypeError: if STAGED_LR is set, ``param_groups`` is None and
            ``model`` is not an ``nn.Module``.
    """
    optim = optim_cfg.NAME
    lr = optim_cfg.LR
    weight_decay = optim_cfg.WEIGHT_DECAY
    momentum = optim_cfg.MOMENTUM
    # "DAMPNING" is the spelling of the config key itself.
    sgd_dampening = optim_cfg.SGD_DAMPNING
    sgd_nesterov = optim_cfg.SGD_NESTEROV
    rmsprop_alpha = optim_cfg.RMSPROP_ALPHA
    adam_beta1 = optim_cfg.ADAM_BETA1
    adam_beta2 = optim_cfg.ADAM_BETA2
    staged_lr = optim_cfg.STAGED_LR
    new_layers = optim_cfg.NEW_LAYERS
    base_lr_mult = optim_cfg.BASE_LR_MULT

    if optim not in AVAI_OPTIMS:
        raise ValueError(
            f"optim must be one of {AVAI_OPTIMS}, but got {optim}"
        )

    if param_groups is not None and staged_lr:
        warnings.warn(
            "staged_lr will be ignored, if you need to use staged_lr, "
            "please bind it with param_groups yourself."
        )

    if param_groups is None:
        if staged_lr:
            param_groups = _build_staged_param_groups(
                model, new_layers, lr, base_lr_mult
            )
        elif isinstance(model, nn.Module):
            param_groups = model.parameters()
        else:
            param_groups = model

    if optim == "adam":
        optimizer = torch.optim.Adam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
        )

    elif optim == "amsgrad":
        optimizer = torch.optim.Adam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
            amsgrad=True,
        )

    elif optim == "sgd":
        optimizer = torch.optim.SGD(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            dampening=sgd_dampening,
            nesterov=sgd_nesterov,
        )

    elif optim == "rmsprop":
        optimizer = torch.optim.RMSprop(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            alpha=rmsprop_alpha,
        )

    elif optim == "radam":
        optimizer = RAdam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
        )

    elif optim == "adamw":
        optimizer = torch.optim.AdamW(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
        )

    else:
        # Defensive guard: only reachable if AVAI_OPTIMS gains a name
        # without a matching branch above.
        raise NotImplementedError(f"Optimizer {optim} not implemented yet!")

    return optimizer