in src/util.py [0:0]
def set_optim(opt, model):
    """Build the optimizer and learning-rate scheduler described by ``opt``.

    Args:
        opt: Options object. Attributes read: ``optim`` ('adam' or 'adamw'),
            ``lr``, ``weight_decay`` (AdamW only), ``scheduler`` ('fixed' or
            'linear'), and, for the 'linear' scheduler, ``warmup_steps``,
            ``scheduler_steps``, ``total_steps`` and ``fixed_lr``.
        model: Object whose ``parameters()`` are handed to the optimizer
            (presumably a ``torch.nn.Module``).

    Returns:
        An ``(optimizer, scheduler)`` tuple. The scheduler is a
        ``FixedScheduler`` or ``WarmupLinearScheduler`` constructed around
        the optimizer.

    Raises:
        ValueError: If ``opt.optim`` or ``opt.scheduler`` names an
            unsupported choice.
    """
    if opt.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    elif opt.optim == 'adamw':
        optimizer = torch.optim.AdamW(
            model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay
        )
    else:
        # Fail fast with a clear message instead of an UnboundLocalError
        # on ``optimizer`` further down.
        raise ValueError(
            f"Unsupported optimizer {opt.optim!r}; expected 'adam' or 'adamw'"
        )

    if opt.scheduler == 'fixed':
        scheduler = FixedScheduler(optimizer)
    elif opt.scheduler == 'linear':
        # Without an explicit scheduler horizon, use the full training length.
        scheduler_steps = (
            opt.total_steps if opt.scheduler_steps is None else opt.scheduler_steps
        )
        scheduler = WarmupLinearScheduler(
            optimizer,
            warmup_steps=opt.warmup_steps,
            scheduler_steps=scheduler_steps,
            min_ratio=0.,
            fixed_lr=opt.fixed_lr,
        )
    else:
        raise ValueError(
            f"Unsupported scheduler {opt.scheduler!r}; expected 'fixed' or 'linear'"
        )
    return optimizer, scheduler