in pytorch_translate/train.py [0:0]
def setup_training_state(args, trainer, task, epoch_itr):
"""Set up the directory for saving checkpoints.
Load pretrained model if specified."""
PathManager.mkdirs(args.save_dir)
# If --restore-file is already present under --save-dir, use that one
# instead of --pretrained-checkpoint-file. The idea is that
# --pretrained-checkpoint-file allows the user to specify restoring from a
# different run's checkpoint (possibly with different training params),
# while not polluting the previous run's checkpoint directory
# with new checkpoints. However, if training gets interrupted
# and the user restarts training, we want to resume from
# the checkpoints under --save-dir, instead of
# restarting again from the old run's checkpoint at
# --pretrained-checkpoint-file.
#
# Note that if args.restore_file is an absolute path, os.path.join() will
# ignore previous directory args and just use the absolute path as is.
checkpoint_path = os.path.join(args.save_dir, args.restore_file)
restore_state = True
if PathManager.isfile(checkpoint_path):
print(
f"| Using --save-dir={args.save_dir}, --restore-file={args.restore_file}."
)
elif args.pretrained_checkpoint_file and PathManager.isfile(
args.pretrained_checkpoint_file
):
checkpoint_path = args.pretrained_checkpoint_file
restore_state = args.load_pretrained_checkpoint_state
print(
f"| Using --pretrained-checkpoint-file={args.pretrained_checkpoint_file}, "
f"--load-pretrained-checkpoint-state={args.load_pretrained_checkpoint_state}."
)
extra_state = default_extra_state(args)
if not PathManager.isfile(checkpoint_path) and args.multi_model_restore_files:
print(f"| Restoring individual models from {args.multi_model_restore_files}")
multi_model.import_individual_models(args.multi_model_restore_files, trainer)
else:
loaded, loaded_extra_state = checkpoint.load_existing_checkpoint(
checkpoint_path=checkpoint_path,
trainer=trainer,
restore_state=restore_state,
)
if loaded_extra_state:
extra_state.update(loaded_extra_state)
# Reset the start time for the current training run.
extra_state["start_time"] = time.time()
# Skips printing all training progress to prevent log spam.
training_progress = extra_state["training_progress"]
extra_state["training_progress"] = (
["...truncated...", training_progress[-1]] if len(training_progress) > 0 else []
)
print(f"| extra_state: {extra_state}")
extra_state["training_progress"] = training_progress
epoch = extra_state["epoch"]
if extra_state["batch_offset"] == 0:
epoch -= 1 # this will be incremented when we call epoch_itr.next_epoch_itr()
epoch_itr.load_state_dict(
{"epoch": epoch, "iterations_in_epoch": extra_state["batch_offset"]}
)
checkpoint_manager = None
if distributed_utils.is_master(args):
checkpoint_manager = checkpoint.CheckpointManager(
num_avg_checkpoints=args.num_avg_checkpoints,
auto_clear_checkpoints=args.auto_clear_checkpoints,
log_verbose=args.log_verbose,
checkpoint_files=extra_state["checkpoint_files"],
)
return extra_state, epoch_itr, checkpoint_manager
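
As a side note, here is a minimal sketch (with hypothetical paths, not taken from the repo) of the os.path.join() behavior that the comment above relies on: a relative --restore-file resolves under --save-dir, while an absolute one is used verbatim and the save directory is ignored.

import os

# Hypothetical example paths, for illustration only.
save_dir = "/checkpoints/run1"

print(os.path.join(save_dir, "checkpoint_last.pt"))
# -> /checkpoints/run1/checkpoint_last.pt

print(os.path.join(save_dir, "/old_run/checkpoint_best.pt"))
# -> /old_run/checkpoint_best.pt  (absolute path wins; save_dir is dropped)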