in torchrec/optim/warmup.py [0:0]
def _get_multiplier(stage: WarmupStage, iter: int) -> float:
    """Return the learning-rate multiplier for ``stage`` at iteration ``iter``.

    The base multiplier is selected by ``stage.policy``:

    * ``LINEAR``   -- ``value + (1 - value) * iter / max_iters``: a linear
      ramp from ``stage.value`` at ``iter == 0`` to ``1.0`` at
      ``iter == stage.max_iters``.
    * ``CONSTANT`` -- ``stage.value``.
    * ``POLY``     -- ``(1 - iter / decay_iters) ** value``.
    * ``STEP``     -- ``value ** (iter // decay_iters)``, i.e. the multiplier
      is scaled by ``stage.value`` once every ``decay_iters`` iterations.
    * ``INVSQRT``  -- ``1 / sqrt(max(iter, 1))``.
    * any other policy -- ``1.0``.

    The base multiplier is then scaled by ``stage.lr_scale``.

    NOTE(review): whether ``iter`` counts from the start of this stage or from
    the start of training is decided by the caller -- confirm there.

    Args:
        stage: Warmup stage whose ``policy``, ``value``, ``max_iters``,
            ``decay_iters`` and ``lr_scale`` fields drive the schedule.
        iter: Current iteration number. (The name shadows the builtin but is
            kept so keyword callers keep working.)

    Returns:
        The multiplier to apply to the base learning rate.

    Raises:
        ZeroDivisionError: For ``LINEAR`` when ``stage.max_iters == 0``, or for
            ``POLY``/``STEP`` when ``stage.decay_iters == 0``.
        ValueError: For ``POLY`` when ``1 - iter / decay_iters`` is negative
            and ``stage.value`` is not an integer (``math.pow`` domain error).
    """
    multiplier = 1.0
    if stage.policy == WarmupPolicy.LINEAR:
        multiplier = stage.value + (1.0 - stage.value) * iter / stage.max_iters
    elif stage.policy == WarmupPolicy.CONSTANT:
        multiplier = stage.value
    elif stage.policy == WarmupPolicy.POLY:
        multiplier = math.pow(1 - iter / stage.decay_iters, stage.value)
    elif stage.policy == WarmupPolicy.STEP:
        multiplier = math.pow(stage.value, iter // stage.decay_iters)
    elif stage.policy == WarmupPolicy.INVSQRT:
        # 1/sqrt(0) raises ZeroDivisionError (and a negative iter would raise
        # ValueError), so clamp to 1: the first iteration gets the same
        # multiplier (1.0) as iteration 1 instead of crashing.
        multiplier = 1.0 / math.sqrt(max(iter, 1))
    return multiplier * stage.lr_scale