def infer_learning_rate()

in vissl/utils/hydra_config.py


def infer_learning_rate(cfg):
    """
    1) Assert the Learning rate here. LR is scaled as per https://arxiv.org/abs/1706.02677.
    to turn this automatic scaling off,
    set config.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.auto_scale=false

    scaled_lr is calculated:
        given base_lr_batch_size = batch size for which the base learning rate is specified,
              base_value = base learning rate value that will be scaled,
              The current batch size is used to determine how to scale the base learning rate
              value.
        scale_factor = (batchsize_per_gpu * world_size) / base_lr_batch_size
        if scaling_type is sqrt, scale factor = sqrt(scale_factor)
        scaled_lr = scale_factor * base_value


    We perform this auto-scaling for head learning rate as well if user wants to use a different
    learning rate for the head

    2) infer the model head params weight decay: if the head should use a different weight
       decay value than the trunk.
       If using different weight decay value for the head, set here. otherwise, the
       same value as trunk will be automatically used.
    """
    if cfg.OPTIMIZER.param_schedulers.lr.auto_lr_scaling.auto_scale:
        world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
        batch_size = cfg.DATA.TRAIN.BATCHSIZE_PER_REPLICA * world_size
        param_schedulers = cfg.OPTIMIZER.param_schedulers.lr
        base_lr = param_schedulers.auto_lr_scaling.base_value
        base_lr_batch_size = param_schedulers.auto_lr_scaling.base_lr_batch_size
        scaling_type = param_schedulers.auto_lr_scaling.scaling_type
        assert scaling_type in [
            "sqrt",
            "linear",
        ], "Only linear | sqrt scaling_types are supported"

        scale_factor = float(batch_size) / base_lr_batch_size
        if scaling_type == "sqrt":
            scale_factor = scale_factor ** 0.5
        scaled_lr = base_lr * scale_factor
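        # e.g. with base_value = 0.3 specified for base_lr_batch_size = 256 and a
        # global batch size of 512 (illustrative numbers, not VISSL defaults),
        # scale_factor = 2.0, so scaled_lr = 0.6 for linear scaling and
        # 0.3 * sqrt(2.0) ~= 0.424 for sqrt scaling.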
        cfg.OPTIMIZER.param_schedulers.lr = get_scaled_lr_scheduler(
            cfg, param_schedulers, scaled_lr
        )

    if not cfg.OPTIMIZER.head_optimizer_params.use_different_lr:
        # if not using a different LR for the head, we set the head LR param scheduler
        # to be the same as the trunk's (weight decay is handled separately below).
        cfg.OPTIMIZER.param_schedulers.lr_head = cfg.OPTIMIZER.param_schedulers.lr
    elif (
        cfg.OPTIMIZER.head_optimizer_params.use_different_lr
        and cfg.OPTIMIZER.param_schedulers.lr_head
        and cfg.OPTIMIZER.param_schedulers.lr_head.auto_lr_scaling.auto_scale
    ):
        # if the user wants a different LR value for the head, then we
        # automatically infer the LR values for the head as well (similar to
        # trunk above)
        world_size = cfg.DISTRIBUTED.NUM_NODES * cfg.DISTRIBUTED.NUM_PROC_PER_NODE
        batch_size = cfg.DATA.TRAIN.BATCHSIZE_PER_REPLICA * world_size
        param_schedulers = cfg.OPTIMIZER.param_schedulers.lr_head
        base_lr = param_schedulers.auto_lr_scaling.base_value
        base_lr_batch_size = param_schedulers.auto_lr_scaling.base_lr_batch_size
        scaling_type = param_schedulers.auto_lr_scaling.scaling_type
        assert scaling_type in [
            "sqrt",
            "linear",
        ], "Only linear | sqrt scaling_types are supported"

        scale_factor = float(batch_size) / base_lr_batch_size
        if scaling_type == "sqrt":
            scale_factor = scale_factor ** 0.5
        scaled_lr = base_lr * scale_factor
        cfg.OPTIMIZER.param_schedulers.lr_head = get_scaled_lr_scheduler(
            cfg, param_schedulers, scaled_lr
        )

    # for the head, if we want to use a different weight decay value,
    # we verify that the specified weight decay value is valid. Otherwise,
    # we infer it and set the head weight decay to be the same as the trunk's.
    if not cfg.OPTIMIZER.head_optimizer_params.use_different_wd:
        cfg.OPTIMIZER.head_optimizer_params.weight_decay = cfg.OPTIMIZER.weight_decay
    else:
        assert (
            cfg.OPTIMIZER.head_optimizer_params.weight_decay >= 0.0
        ), "weight decay for head should be >=0"
    return cfg
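

For reference, the scaling rule implemented above reduces to the arithmetic in the
minimal standalone sketch below. The numbers (base LR, batch sizes, node/GPU counts)
are illustrative assumptions, not VISSL defaults, and the sketch does not build a full
VISSL config or call infer_learning_rate.

    import math

    # Assumed values: base LR 0.3 tuned for a batch size of 256,
    # training on 2 nodes x 8 GPUs with 32 images per GPU.
    base_value = 0.3
    base_lr_batch_size = 256
    world_size = 2 * 8                      # NUM_NODES * NUM_PROC_PER_NODE
    batchsize_per_replica = 32              # DATA.TRAIN.BATCHSIZE_PER_REPLICA

    scale_factor = float(batchsize_per_replica * world_size) / base_lr_batch_size  # 2.0
    print(base_value * scale_factor)              # linear scaling -> 0.6
    print(base_value * math.sqrt(scale_factor))   # sqrt scaling   -> ~0.424

In the function itself, the scaled value is not used directly like this; it is written
back into the LR param scheduler config via get_scaled_lr_scheduler.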