in common/sagemaker_rl/ray_launcher.py [0:0]
def set_up_checkpoint(self, config=None):
    """Set ``config["training"]["restore"]`` to a checkpoint file when one is available.

    Resolution order:

    1. If the config already contains ``training.restore``, it is returned unchanged.
    2. Otherwise, if ``CHECKPOINT_DIR`` does not exist, training starts from scratch.
    3. Otherwise ``self._checkpoint_dir_finder(CHECKPOINT_DIR)`` supplies the
       directory to inspect. It must contain 2 or 3 entries. The number of
       entries whose name ends in ``tune_metadata`` or ``extra_data`` must be
       exactly 1 for ray >= 0.6.5, and exactly 2 for older ray. A remaining
       (non-metadata) entry becomes the restore path.

    Args:
        config: Ray experiment config dict. ``None`` is treated as an empty dict.

    Returns:
        The config dict, with ``training.restore`` set if a checkpoint was found.

    Raises:
        RuntimeError: if the checkpoint directory does not hold 2 or 3 entries,
            or the number of metadata files does not match the ray version.
    """
    import re  # only needed to parse the ray version below

    if config is None:
        config = {}

    try:
        checkpoint_dir = config["training"]["restore"]
    except KeyError:
        pass
    else:
        print("Found checkpoint dir %s in user config." % checkpoint_dir)
        return config

    if not os.path.exists(CHECKPOINT_DIR):
        print("No checkpoint path specified. Training from scratch.")
        return config

    checkpoint_dir = self._checkpoint_dir_finder(CHECKPOINT_DIR)
    # validate the contents
    print("checkpoint_dir is {}".format(checkpoint_dir))
    checkpoint_dir_contents = os.listdir(checkpoint_dir)
    if len(checkpoint_dir_contents) not in [2, 3]:
        raise RuntimeError(
            f"Unexpected files {checkpoint_dir_contents} in checkpoint dir. "
            "Please check ray documents for the correct checkpoint format."
        )

    # Split entries into metadata files and candidate checkpoint (data) files.
    metadata_files = []
    data_files = []
    for filename in checkpoint_dir_contents:
        if filename.endswith(("tune_metadata", "extra_data")):
            metadata_files.append(filename)
        else:
            data_files.append(filename)

    # Compare versions numerically: string comparison misorders multi-digit
    # components (e.g. "0.6.10" < "0.6.5" lexically).
    ray_version = tuple(int(part) for part in re.findall(r"\d+", ray.__version__)[:3])
    if ray_version >= (0, 6, 5):
        if len(metadata_files) != 1:
            raise RuntimeError("Failed to find .tune_metadata to restore checkpoint.")
    elif len(metadata_files) != 2:
        raise RuntimeError(
            "Failed to find .tune_metadata or .extra_data to restore checkpoint"
        )

    checkpoint_file_in_container = ""
    if data_files:
        # Assumes Ray's "<checkpoint>.tune_metadata" naming. Prefer the data file
        # paired with that metadata, so that when two non-metadata entries exist
        # the choice does not depend on os.listdir ordering. Fall back to a
        # deterministic (sorted) pick.
        paired = {
            name[: -len(".tune_metadata")]
            for name in metadata_files
            if name.endswith(".tune_metadata")
        }
        data_files.sort()
        chosen = next((name for name in data_files if name in paired), data_files[-1])
        checkpoint_file_in_container = os.path.join(checkpoint_dir, chosen)

    if checkpoint_file_in_container:
        print(
            "Found checkpoint: %s. Setting `restore` path in ray config."
            % checkpoint_file_in_container
        )
        # The "training" section may be absent; create it instead of raising KeyError.
        config.setdefault("training", {})["restore"] = checkpoint_file_in_container
    else:
        print("No valid checkpoint found in %s. Training from scratch." % checkpoint_dir)
    return config