tzrec/optim/lr_scheduler.py (85 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group; # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # http://www.apache.org/licenses/LICENSE-2.0 # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import bisect import math from typing import List from torch.optim.lr_scheduler import LRScheduler from torch.optim.optimizer import Optimizer from tzrec.utils.load_class import get_register_class_meta _LR_CLASS_MAP = {} _meta_cls = get_register_class_meta(_LR_CLASS_MAP) class BaseLR(LRScheduler, metaclass=_meta_cls): """LearningRate Scheduler base class.""" def __init__(self, optimizer: Optimizer, by_epoch: bool = False) -> None: self._by_epoch = by_epoch super().__init__(optimizer) @property def by_epoch(self) -> bool: """Schedule by epoch or not.""" return self._by_epoch class ConstantLR(BaseLR): """Constant LearningRate Scheduler.""" def __init__(self, optimizer: Optimizer) -> None: super().__init__(optimizer, by_epoch=True) # pyre-ignore [3] def get_lr(self): """Calculates the learning rate.""" return self.base_lrs class ExponentialDecayLR(BaseLR): """Exponential Decay LearningRate Scheduler. Args: optimizer (Optimizer): an instance of Optimizer. decay_size (int): decay steps or epochs. decay_factor (float): decay rate. staircase (bool): if true, decay the learning rate at discrete intervals. warmup_learning_rate (float): warmup start learning rate. warmup_size (int): warmup steps or epochs. min_learning_rate (float): minimum learning rate. by_epoch (bool): schedule by epoch or by step. """ def __init__( self, optimizer: Optimizer, decay_size: int, decay_factor: float, staircase: bool = True, warmup_learning_rate: float = 0.0, warmup_size: int = 0, min_learning_rate: float = 0.0, by_epoch: bool = False, ) -> None: self._decay_size = decay_size self._decay_factor = decay_factor self._staircase = staircase self._warmup_learning_rate = warmup_learning_rate self._warmup_size = warmup_size self._min_learning_rate = min_learning_rate super().__init__(optimizer, by_epoch=by_epoch) # pyre-ignore [3] def get_lr(self): """Calculates the learning rate.""" step_count = max(self._step_count - 1, 0) if step_count < self._warmup_size: scale = step_count / self._warmup_size lr = [ (base_lr - self._warmup_learning_rate) * scale + self._warmup_learning_rate for base_lr in self.base_lrs ] else: p = (step_count - self._warmup_size) / self._decay_size if self._staircase: p = math.floor(p) scale = math.pow(self._decay_factor, p) lr = [ max(base_lr * scale, self._min_learning_rate) for base_lr in self.base_lrs ] return lr class ManualStepLR(BaseLR): """Manual Step LearningRate Scheduler. Args: optimizer (Optimizer): an instance of Optimizer. schedule_steps (list): a list of global steps or epochs at which to switch learning. learning_rates (list): a list of learning rates corresponding to intervals. warmup (bool): whether to linearly interpolate learning rates for steps in [0, schedule_steps[0]]. by_epoch (bool): schedule by epoch or by step. """ def __init__( self, optimizer: Optimizer, schedule_sizes: List[int], learning_rates: List[float], warmup: bool = False, by_epoch: bool = False, ) -> None: self._schedule_sizes = schedule_sizes self._learning_rates = learning_rates self._warmup = warmup super().__init__(optimizer, by_epoch=by_epoch) # pyre-ignore [3] def get_lr(self): """Calculates the learning rate.""" step_count = max(self._step_count - 1, 0) idx = bisect.bisect_left(self._schedule_sizes, step_count) if idx > 0: lr = [self._learning_rates[idx - 1] for _ in self.base_lrs] elif self._warmup: scale = step_count / self._schedule_sizes[0] lr = [ (self._learning_rates[0] - base_lr) * scale + base_lr for base_lr in self.base_lrs ] else: lr = self.base_lrs return lr