in bring-your-own-container/fairseq_translation/fairseq/train_driver.py [0:0]
def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay the dataloader to match its state.

    Looks for ``args.restore_file`` inside ``args.save_dir`` (creating the
    directory if needed). When the file exists, the trainer restores model /
    optimizer state from it, the epoch iterator is fast-forwarded to the
    position recorded in the checkpoint, and the LR scheduler is stepped to
    the restored epoch/update counts.

    Args:
        args: parsed CLI namespace; uses ``save_dir``, ``restore_file``,
            ``reset_optimizer``, ``reset_lr_scheduler`` and
            ``optimizer_overrides`` (a Python-literal dict string, e.g.
            ``"{'lr': 0.1}"``).
        trainer: object exposing ``load_checkpoint``, ``lr_step``,
            ``lr_step_update`` and ``get_num_updates``.
        epoch_itr: epoch iterator exposing ``load_state_dict`` and ``epoch``.

    Returns:
        bool: True if a checkpoint file was found (and loading was attempted),
        False if no checkpoint exists at the expected path.
    """
    # Local import keeps this edit self-contained; the module's top-level
    # import block is outside this view.
    import ast

    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)

    # Guard clause: nothing to restore.
    if not os.path.isfile(checkpoint_path):
        return False

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        args.reset_optimizer,
        args.reset_lr_scheduler,
        # SECURITY: literal_eval only parses Python literals, unlike the
        # original eval(), which would execute arbitrary code supplied via
        # the --optimizer-overrides CLI argument.
        ast.literal_eval(args.optimizer_overrides),
    )
    if extra_state is not None:
        # Replay the train iterator so data order matches the checkpoint.
        epoch_itr.load_state_dict(extra_state["train_iterator"])
        print(
            "| loaded checkpoint {} (epoch {} @ {} updates)".format(
                checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()
            )
        )
        # Re-sync the LR scheduler with the restored epoch/update counters.
        trainer.lr_step(epoch_itr.epoch)
        trainer.lr_step_update(trainer.get_num_updates())
        if "best" in extra_state:
            # Carry the best validation metric across restarts so
            # save_checkpoint keeps comparing against it.
            save_checkpoint.best = extra_state["best"]
    return True