def _get_multiplier()

in torchrec/optim/warmup.py [0:0]


def _get_multiplier(stage: WarmupStage, iter: int) -> float:
    """Return the learning-rate multiplier for ``stage`` at iteration ``iter``.

    A base multiplier is selected by ``stage.policy`` and then scaled by
    ``stage.lr_scale``:

    * ``LINEAR``: ``stage.value + (1 - stage.value) * iter / stage.max_iters``,
      i.e. ``stage.value`` at ``iter == 0`` moving to 1.0 at
      ``iter == stage.max_iters``.
    * ``CONSTANT``: ``stage.value``.
    * ``POLY``: ``(1 - iter / stage.decay_iters) ** stage.value``. The base is
      clamped at 0, so the multiplier stays decayed once ``iter`` passes
      ``stage.decay_iters`` instead of raising or growing again.
    * ``STEP``: ``stage.value ** (iter // stage.decay_iters)``.
    * ``INVSQRT``: ``1 / sqrt(iter)``, with ``iter`` floored at 1 so that
      iteration 0 does not divide by zero.
    * Any other policy: 1.0.

    NOTE(review): ``POLY`` and ``STEP`` divide by ``stage.decay_iters``; this
    presumably relies on callers having normalized it to a positive value --
    confirm against where stages are constructed.

    Args:
        stage: The warmup stage whose policy and parameters are applied.
        iter: The current iteration. The name shadows the builtin ``iter`` but
            is kept so keyword callers keep working.

    Returns:
        The multiplier, already scaled by ``stage.lr_scale``.
    """
    multiplier = 1.0  # default for policies not handled below
    if stage.policy == WarmupPolicy.LINEAR:
        multiplier = stage.value + (1.0 - stage.value) * iter / stage.max_iters
    elif stage.policy == WarmupPolicy.CONSTANT:
        multiplier = stage.value
    elif stage.policy == WarmupPolicy.POLY:
        # Clamp the base at 0: past decay_iters it would turn negative, making
        # math.pow raise for fractional exponents (and rise again for even ones).
        remaining = max(0.0, 1.0 - iter / stage.decay_iters)
        multiplier = math.pow(remaining, stage.value)
    elif stage.policy == WarmupPolicy.STEP:
        multiplier = math.pow(stage.value, iter // stage.decay_iters)
    elif stage.policy == WarmupPolicy.INVSQRT:
        # Floor at 1 so iteration 0 yields 1.0 instead of ZeroDivisionError.
        multiplier = 1.0 / math.sqrt(max(iter, 1))
    return multiplier * stage.lr_scale