in train_helpers.py
import torch
from torch.optim import AdamW


def load_opt(H, vae, logprint):
    # Build the AdamW optimizer and a linear-warmup LR schedule from the hyperparameter object H.
    optimizer = AdamW(vae.parameters(), weight_decay=H.wd, lr=H.lr, betas=(H.adam_beta1, H.adam_beta2))
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_warmup(H.warmup_iters))

    # Optionally restore optimizer state from a (possibly remote) checkpoint.
    if H.restore_optimizer_path:
        optimizer.load_state_dict(
            torch.load(distributed_maybe_download(H.restore_optimizer_path, H.local_rank, H.mpi_size),
                       map_location='cpu'))

    # Resume training bookkeeping (best eval loss, iteration, epoch) from a saved log, or start fresh.
    if H.restore_log_path:
        cur_eval_loss, iterate, starting_epoch = restore_log(H.restore_log_path, H.local_rank, H.mpi_size)
    else:
        cur_eval_loss, iterate, starting_epoch = float('inf'), 0, 0
    logprint('starting at epoch', starting_epoch, 'iterate', iterate, 'eval loss', cur_eval_loss)
    return optimizer, scheduler, cur_eval_loss, iterate, starting_epoch
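
linear_warmup, distributed_maybe_download, and restore_log are helpers defined elsewhere in train_helpers.py. As a rough illustration of how the schedule hooks into LambdaLR, a minimal warmup factory could look like the sketch below; the exact ramp shape is an assumption, not the repository's verbatim code.

# Sketch only (assumes warmup_iters > 0): returns a LambdaLR multiplier that ramps
# the learning rate linearly from 0 to its base value over the first warmup_iters
# steps, then holds it at the base value.
def linear_warmup(warmup_iters):
    def schedule(iteration):
        return min(1.0, iteration / warmup_iters)
    return schedule

A caller would then unpack the five returned values, e.g.
optimizer, scheduler, cur_eval_loss, iterate, starting_epoch = load_opt(H, vae, logprint)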