def get_lr_scheduler()

in src/util.py


from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau


def get_lr_scheduler(config, optimizer):
    """
    Returns a bool of (update_lr_per_step, lr_scheduler)

    """
    lr = float(config.get("TRAIN", "learning_rate"))
    scheduler_type = config.get("TRAIN", "scheduler", fallback="plateau")

    if scheduler_type == "plateau":
        eps_tolerance = float(config.get("TRAIN", "eps_tolerance", fallback='0'))
        patience = int(config.get("TRAIN", "patience", fallback='1'))
        decay_factor = float(config.get("TRAIN", "decay_factor", fallback='0.5'))
        scheduler = ReduceLROnPlateau(optimizer,
                                      factor=decay_factor,
                                      patience=patience,
                                      eps=eps_tolerance)
        update_lr_per_step = False
    elif scheduler_type == "one_cycle":
        max_train_steps = int(config.get("TRAIN", "max_train_steps"))
        anneal_strategy = config.get("TRAIN", "anneal_strategy", fallback="cos")
        pct_start = float(config.get("TRAIN", "pct_start", fallback=0.3))
        scheduler = OneCycleLR(optimizer,
                               max_lr=lr,
                               total_steps=max_train_steps,
                               anneal_strategy=anneal_strategy,
                               pct_start=pct_start)
        update_lr_per_step = True
    else:
        raise ValueError(f"Invalid scheduler type: {scheduler_type}")

    return (update_lr_per_step, scheduler)
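
For context, a minimal usage sketch follows. The [TRAIN] values, the stand-in model, and the training loop are illustrative assumptions, not code from this repo; it also assumes get_lr_scheduler is importable from src/util.py.

# Hypothetical usage sketch: the config values, model, and loop below are
# illustrative assumptions, not code from this repository.
import configparser

import torch
from torch import nn

from util import get_lr_scheduler  # assumes src/ is on the import path

config = configparser.ConfigParser()
config.read_string("""
[TRAIN]
learning_rate = 0.001
scheduler = one_cycle
max_train_steps = 1000
""")

model = nn.Linear(10, 1)  # stand-in model
optimizer = torch.optim.SGD(model.parameters(),
                            lr=float(config.get("TRAIN", "learning_rate")))
update_lr_per_step, scheduler = get_lr_scheduler(config, optimizer)

for step in range(int(config.get("TRAIN", "max_train_steps"))):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 10)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    if update_lr_per_step:
        scheduler.step()  # OneCycleLR advances once per optimizer step

if not update_lr_per_step:
    # ReduceLROnPlateau is instead stepped once per epoch with a metric.
    scheduler.step(0.123)  # placeholder validation loss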