in src/util.py [0:0]
def get_lr_scheduler(config, optimizer):
    """Build the learning-rate scheduler selected in the ``[TRAIN]`` config section.

    Args:
        config: Config object read via ``config.get("TRAIN", key, fallback=...)``
            (presumably a ``configparser.ConfigParser``; the ``fallback=``
            keyword matches that API). Keys read from ``[TRAIN]``:
            ``scheduler`` ("plateau" by default, or "one_cycle");
            for "plateau": ``eps_tolerance`` (default 0), ``patience``
            (default 1), ``decay_factor`` (default 0.5);
            for "one_cycle": ``learning_rate`` (required, used as ``max_lr``),
            ``max_train_steps`` (required), ``anneal_strategy`` (default
            "cos"), ``pct_start`` (default 0.3).
        optimizer: The torch optimizer whose learning rate is scheduled.

    Returns:
        A tuple ``(update_lr_per_step, scheduler)``. ``update_lr_per_step`` is
        True for ``OneCycleLR``, which is built with ``total_steps`` and so is
        meant to be stepped every training step. It is False for
        ``ReduceLROnPlateau``, which PyTorch steps with a monitored metric.

    Raises:
        ValueError: If ``scheduler`` names an unsupported scheduler type.
        Whatever ``config.get`` raises when a required key
        (``learning_rate``/``max_train_steps`` for one_cycle) is missing.
    """
    scheduler_type = config.get("TRAIN", "scheduler", fallback="plateau")

    if scheduler_type == "plateau":
        eps_tolerance = float(config.get("TRAIN", "eps_tolerance", fallback='0'))
        patience = int(config.get("TRAIN", "patience", fallback='1'))
        decay_factor = float(config.get("TRAIN", "decay_factor", fallback='0.5'))
        scheduler = ReduceLROnPlateau(optimizer,
                                      factor=decay_factor,
                                      patience=patience,
                                      eps=eps_tolerance)
        update_lr_per_step = False
    elif scheduler_type == "one_cycle":
        # Only this branch uses the base learning rate (as the cycle's peak),
        # so it is read here rather than unconditionally.
        lr = float(config.get("TRAIN", "learning_rate"))
        max_train_steps = int(config.get("TRAIN", "max_train_steps"))
        anneal_strategy = config.get("TRAIN", "anneal_strategy", fallback="cos")
        pct_start = float(config.get("TRAIN", "pct_start", fallback='0.3'))
        scheduler = OneCycleLR(optimizer,
                               max_lr=lr,
                               total_steps=max_train_steps,
                               anneal_strategy=anneal_strategy,
                               pct_start=pct_start)
        update_lr_per_step = True
    else:
        # ValueError subclasses Exception, so existing broad handlers still catch it.
        raise ValueError(f"Invalid scheduler type: {scheduler_type}")
    return (update_lr_per_step, scheduler)