in common/sagemaker_rl/ray_launcher.py [0:0]
def copy_checkpoints_to_model_output(self):
    """Copy the newest checkpoint files from INTERMEDIATE_DIR to MODEL_OUTPUT_DIR.

    INTERMEDIATE_DIR is walked recursively for files whose name starts with
    "checkpoint". The scan is attempted up to 6 times, pausing 5 seconds
    after each attempt that finds nothing. The paths found are sorted with
    ``natural_keys`` and the last two are treated as the latest checkpoint.

    Of those (at most two) paths, the number ending in "tune_metadata" or
    "extra_data" must be exactly 1 when the installed ray is 0.6.5 or newer,
    and exactly 2 otherwise. Each selected file is then copied to
    MODEL_OUTPUT_DIR as "checkpoint" followed by the source file's extension.

    Raises:
        RuntimeError: if no checkpoint file is found after all attempts, or
            if the selected files do not contain the expected number of
            metadata files.
    """
    # Function-scope import so this method does not depend on what the
    # module happens to import at the top of the file.
    import re

    def _release_numbers(version):
        """Return the leading numeric components of *version* as a tuple."""
        numbers = []
        for piece in version.split("."):
            leading_digits = re.match(r"\d+", piece)
            if leading_digits is None:
                # Stop at the first non-numeric component (e.g. "dev0").
                break
            numbers.append(int(leading_digits.group()))
        return tuple(numbers)

    max_attempts = 6
    retry_delay_seconds = 5

    for _attempt in range(max_attempts):
        checkpoints = []
        for root, _dirs, filenames in os.walk(INTERMEDIATE_DIR):
            for filename in filenames:
                if filename.startswith("checkpoint"):
                    checkpoints.append(os.path.join(root, filename))
        if checkpoints:
            # Success on any attempt, including the last one, must not fall
            # through to the failure path below.
            break
        time.sleep(retry_delay_seconds)
    else:
        # Only reached when every attempt came back empty.
        raise RuntimeError("Failed to find checkpoint files")

    # NOTE(review): natural_keys is defined elsewhere; presumably it orders
    # embedded numbers numerically so the newest checkpoint sorts last.
    checkpoints.sort(key=natural_keys)
    latest_checkpoints = checkpoints[-2:]
    validation = sum(
        1
        for path in latest_checkpoints
        if path.endswith(("tune_metadata", "extra_data"))
    )

    # Compare versions numerically: as plain strings "0.10.0" < "0.6.5".
    if _release_numbers(ray.__version__) >= (0, 6, 5):
        if validation != 1:
            raise RuntimeError("Failed to save checkpoint files - .tune_metadata")
    elif validation != 2:
        raise RuntimeError(
            "Failed to save checkpoint files - .tune_metadata or .extra_data"
        )

    for source_path in latest_checkpoints:
        # Drop the iteration-specific stem so the output name is stable:
        # "checkpoint" plus whatever extension the source file had.
        _, ext = os.path.splitext(source_path)
        destination_path = os.path.join(MODEL_OUTPUT_DIR, "checkpoint%s" % ext)
        copyfile(source_path, destination_path)
        print("Saved the checkpoint file %s as %s" % (source_path, destination_path))