# Excerpt: load_checkpoint()
# Source file: bring-your-own-container/fairseq_translation/fairseq/train_driver.py


def load_checkpoint(args, trainer, epoch_itr):
    """Load a checkpoint and replay the dataloader to match its state.

    Args:
        args: parsed CLI namespace; this function reads ``save_dir``,
            ``restore_file``, ``reset_optimizer``, ``reset_lr_scheduler``
            and ``optimizer_overrides`` (a Python dict literal as a string,
            e.g. ``"{'lr': 0.1}"``).
        trainer: trainer object providing ``load_checkpoint``, ``lr_step``,
            ``lr_step_update`` and ``get_num_updates``.
        epoch_itr: epoch iterator providing ``load_state_dict`` and an
            ``epoch`` attribute.

    Returns:
        bool: True if a checkpoint file was found at
        ``save_dir/restore_file`` (even if it carried no extra state),
        False otherwise.
    """
    import ast  # local import: file's top-of-file imports are outside this excerpt

    os.makedirs(args.save_dir, exist_ok=True)
    checkpoint_path = os.path.join(args.save_dir, args.restore_file)
    if os.path.isfile(checkpoint_path):
        # SECURITY: was eval() on a CLI-supplied string, which executes
        # arbitrary code. literal_eval accepts only Python literals (dicts,
        # numbers, strings, ...) and raises ValueError on anything else.
        extra_state = trainer.load_checkpoint(
            checkpoint_path,
            args.reset_optimizer,
            args.reset_lr_scheduler,
            ast.literal_eval(args.optimizer_overrides),
        )
        if extra_state is not None:
            # Replay the train iterator so data order resumes exactly where
            # the checkpoint left off.
            epoch_itr.load_state_dict(extra_state["train_iterator"])

            print(
                "| loaded checkpoint {} (epoch {} @ {} updates)".format(
                    checkpoint_path, epoch_itr.epoch, trainer.get_num_updates()
                )
            )

            # Re-sync the LR scheduler with the restored epoch/update counts.
            trainer.lr_step(epoch_itr.epoch)
            trainer.lr_step_update(trainer.get_num_updates())
            if "best" in extra_state:
                # NOTE(review): save_checkpoint is presumably a sibling
                # function in this module whose ``best`` attribute tracks the
                # best validation score so far — confirm against full file.
                save_checkpoint.best = extra_state["best"]
        return True
    return False