Dassl.pytorch/dassl/optim/lr_scheduler.py (114 lines of code) (raw):
"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.optim.lr_scheduler import _LRScheduler
# Scheduler names accepted by build_lr_scheduler (checked against
# optim_cfg.LR_SCHEDULER before any scheduler is constructed).
AVAI_SCHEDS = ["single_step", "multi_step", "cosine"]
class _BaseWarmupScheduler(_LRScheduler):
    """Base class for schedulers that warm up before handing over to a successor.

    While ``last_epoch`` is below ``warmup_epoch`` the learning rate comes from
    the subclass's ``get_lr``. From then on each ``step`` is delegated to
    ``successor`` and its most recent learning rates are mirrored.

    Args:
        optimizer: optimizer forwarded to ``_LRScheduler``.
        successor: scheduler that takes over once warmup is finished.
        warmup_epoch: value of ``last_epoch`` at which control passes to
            ``successor``.
        last_epoch: forwarded to ``_LRScheduler``.
        verbose: forwarded to ``_LRScheduler`` only when truthy (see the
            note in ``__init__``).
    """

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        last_epoch=-1,
        verbose=False
    ):
        # Both attributes must exist before the parent constructor runs:
        # PyTorch performs an initial step() there, which reads them.
        self.successor = successor
        self.warmup_epoch = warmup_epoch
        if verbose:
            # Passed by keyword so that a PyTorch release without this
            # parameter fails with a message that names ``verbose``.
            super().__init__(optimizer, last_epoch, verbose=verbose)
        else:
            # ``verbose`` is deprecated in recent PyTorch releases and absent
            # from newer ones. Passing the default explicitly would emit a
            # deprecation warning or raise TypeError, so it is omitted.
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rates for the current epoch (subclass hook)."""
        raise NotImplementedError

    def step(self, epoch=None):
        """Advance by one step.

        During warmup this is the regular ``_LRScheduler.step``. Afterwards
        the call is forwarded to ``successor``; ``self.last_epoch`` is no
        longer advanced on that path, the successor keeps its own counter.

        Args:
            epoch: forwarded unchanged to whichever ``step`` is invoked.
        """
        if self.last_epoch >= self.warmup_epoch:
            self.successor.step(epoch)
            # Keep get_last_lr() on this wrapper in sync with the successor.
            self._last_lr = self.successor.get_last_lr()
        else:
            super().step(epoch)
class ConstantWarmupScheduler(_BaseWarmupScheduler):
    """Warmup that holds a fixed learning rate until ``warmup_epoch``.

    Args:
        optimizer: optimizer forwarded to the base scheduler.
        successor: scheduler used once warmup is over.
        warmup_epoch: value of ``last_epoch`` at which warmup ends.
        cons_lr: learning rate returned for every entry of ``base_lrs``
            while warming up.
        last_epoch: forwarded to the base scheduler.
        verbose: forwarded to the base scheduler.
    """

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        cons_lr,
        last_epoch=-1,
        verbose=False
    ):
        # Stored first: get_lr() reads it, and PyTorch's scheduler
        # constructor performs an initial step that calls get_lr().
        self.cons_lr = cons_lr
        super().__init__(
            optimizer,
            successor,
            warmup_epoch,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def get_lr(self):
        """Return ``cons_lr`` per base lr during warmup, else the successor's lrs."""
        still_warming_up = self.last_epoch < self.warmup_epoch
        if still_warming_up:
            return [self.cons_lr] * len(self.base_lrs)
        return self.successor.get_last_lr()
class LinearWarmupScheduler(_BaseWarmupScheduler):
    """Warmup that scales each base learning rate linearly with the epoch.

    At epoch 0 every entry is ``min_lr``. For later warmup epochs each entry
    is ``base_lr * last_epoch / warmup_epoch``.

    Args:
        optimizer: optimizer forwarded to the base scheduler.
        successor: scheduler used once warmup is over.
        warmup_epoch: value of ``last_epoch`` at which warmup ends.
        min_lr: learning rate used at epoch 0.
        last_epoch: forwarded to the base scheduler.
        verbose: forwarded to the base scheduler.
    """

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        min_lr,
        last_epoch=-1,
        verbose=False
    ):
        # Stored first: get_lr() reads it, and PyTorch's scheduler
        # constructor performs an initial step that calls get_lr().
        self.min_lr = min_lr
        super().__init__(
            optimizer,
            successor,
            warmup_epoch,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def get_lr(self):
        """Return the warmup learning rates, or the successor's after warmup."""
        epoch = self.last_epoch
        if epoch >= self.warmup_epoch:
            return self.successor.get_last_lr()
        if epoch == 0:
            # The linear ramp would give 0 here, so min_lr is used instead.
            return [self.min_lr] * len(self.base_lrs)
        ramped = []
        for base_lr in self.base_lrs:
            ramped.append(base_lr * epoch / self.warmup_epoch)
        return ramped
def build_lr_scheduler(optimizer, optim_cfg):
    """A function wrapper for building a learning rate scheduler.

    Args:
        optimizer (Optimizer): an Optimizer.
        optim_cfg (CfgNode): optimization config. The fields read are
            ``LR_SCHEDULER``, ``STEPSIZE``, ``GAMMA``, ``MAX_EPOCH`` and
            ``WARMUP_EPOCH``. When ``WARMUP_EPOCH > 0`` it also reads
            ``WARMUP_RECOUNT``, ``WARMUP_TYPE`` and, depending on the type,
            ``WARMUP_CONS_LR`` or ``WARMUP_MIN_LR``.

    Returns:
        The scheduler selected by ``LR_SCHEDULER``, wrapped in a warmup
        scheduler when ``WARMUP_EPOCH > 0``.

    Raises:
        ValueError: if ``LR_SCHEDULER`` is not in ``AVAI_SCHEDS`` or
            ``WARMUP_TYPE`` is neither ``"constant"`` nor ``"linear"``.
        TypeError: if ``STEPSIZE`` has the wrong type for the chosen
            scheduler.
    """
    lr_scheduler = optim_cfg.LR_SCHEDULER
    stepsize = optim_cfg.STEPSIZE
    gamma = optim_cfg.GAMMA
    max_epoch = optim_cfg.MAX_EPOCH

    if lr_scheduler not in AVAI_SCHEDS:
        raise ValueError(
            f"scheduler must be one of {AVAI_SCHEDS}, but got {lr_scheduler}"
        )

    if lr_scheduler == "single_step":
        # A list is tolerated for convenience; only its last entry is used.
        if isinstance(stepsize, (list, tuple)):
            stepsize = stepsize[-1]

        if not isinstance(stepsize, int):
            raise TypeError(
                "For single_step lr_scheduler, stepsize must "
                f"be an integer, but got {type(stepsize)}"
            )

        # A non-positive step size falls back to one decay at max_epoch.
        if stepsize <= 0:
            stepsize = max_epoch

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=stepsize, gamma=gamma
        )

    elif lr_scheduler == "multi_step":
        if not isinstance(stepsize, (list, tuple)):
            raise TypeError(
                "For multi_step lr_scheduler, stepsize must "
                f"be a list, but got {type(stepsize)}"
            )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=stepsize, gamma=gamma
        )

    elif lr_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(max_epoch)
        )

    if optim_cfg.WARMUP_EPOCH > 0:
        if not optim_cfg.WARMUP_RECOUNT:
            # Without recounting, the wrapped scheduler's epoch counter is
            # moved to WARMUP_EPOCH so it resumes from there after warmup.
            scheduler.last_epoch = optim_cfg.WARMUP_EPOCH

        warmup_type = optim_cfg.WARMUP_TYPE

        if warmup_type == "constant":
            scheduler = ConstantWarmupScheduler(
                optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
                optim_cfg.WARMUP_CONS_LR
            )

        elif warmup_type == "linear":
            scheduler = LinearWarmupScheduler(
                optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
                optim_cfg.WARMUP_MIN_LR
            )

        else:
            # Name the offending value, as the LR_SCHEDULER check does.
            raise ValueError(
                "warmup type must be one of ['constant', 'linear'], "
                f"but got {warmup_type}"
            )

    return scheduler