in robot_ws/src/rl_agent/markov/s3_boto_data_store.py [0:0]
def save_to_store(self, *, num_checkpoints_to_keep=4):
    """Upload the local checkpoint directory to S3 and prune an old checkpoint.

    Steps:

    1. If ``self.graph_manager`` is set, write a frozen graph into
       ``self.params.checkpoint_dir``.
    2. Collect every file under the checkpoint directory.
    3. Take a lock by writing an empty object at the lock-file key (after
       deleting any existing one).
    4. Upload every file except the ``CHECKPOINT_METADATA_FILENAME`` file.
    5. Upload the metadata file last, so it is only replaced once all other
       files are in S3.
    6. Release the lock by deleting the lock object.
    7. Delete all S3 objects whose key starts with the prefix for checkpoint
       number ``latest - num_checkpoints_to_keep``.

    Args:
        num_checkpoints_to_keep: Offset from the current checkpoint number
            of the checkpoint whose objects are deleted on this call.
            Defaults to 4, the value this method previously hard-coded.

    Returns:
        True when the method completes.

    Raises:
        ValueError: If ``num_checkpoints_to_keep`` is less than 1. A value of
            0 would delete the checkpoint that was just uploaded.
        FileNotFoundError: If no ``CHECKPOINT_METADATA_FILENAME`` file exists
            under the checkpoint directory. This is checked before the lock
            is taken, so nothing is written to S3 in that case.

    Exceptions raised by the S3 client propagate unchanged. If one happens
    while the lock is held, the lock object is still deleted before the
    exception leaves this method.
    """
    if num_checkpoints_to_keep < 1:
        raise ValueError("num_checkpoints_to_keep must be at least 1, got %s"
                         % num_checkpoints_to_keep)

    s3_client = self._get_client()
    bucket = self.params.bucket
    checkpoint_dir = self.params.checkpoint_dir
    lock_key = self._get_s3_key(self.params.lock_file)

    if self.graph_manager:
        utils.write_frozen_graph(self.graph_manager, checkpoint_dir)

    # Gather the files before touching S3. Previously a missing metadata file
    # surfaced as a TypeError only after the lock had been written, which left
    # the lock object behind in S3.
    data_files = []
    metadata_file = None
    for root, _dirs, files in os.walk(checkpoint_dir):
        for filename in files:
            path = os.path.abspath(os.path.join(root, filename))
            if filename == CHECKPOINT_METADATA_FILENAME:
                # Held back and uploaded after every other file.
                metadata_file = path
            else:
                data_files.append(path)
    if metadata_file is None:
        raise FileNotFoundError("Checkpoint metadata file %s not found under %s"
                                % (CHECKPOINT_METADATA_FILENAME, checkpoint_dir))

    # Delete any existing lock file, then take the lock by writing an empty
    # object at the same key.
    s3_client.delete_object(Bucket=bucket, Key=lock_key)
    s3_client.upload_fileobj(Fileobj=io.BytesIO(b''),
                             Bucket=bucket,
                             Key=lock_key)
    try:
        for path in data_files:
            s3_client.upload_file(Filename=path,
                                  Bucket=bucket,
                                  Key=self._get_s3_key(os.path.relpath(path, checkpoint_dir)))
        # The metadata file goes last, only after all the other checkpoint
        # files have been uploaded.
        s3_client.upload_file(Filename=metadata_file,
                              Bucket=bucket,
                              Key=self._get_s3_key(os.path.relpath(metadata_file, checkpoint_dir)))
    finally:
        # Always release the lock so a failed upload does not leave the lock
        # object in S3.
        s3_client.delete_object(Bucket=bucket, Key=lock_key)

    checkpoint = self._get_current_checkpoint()
    if checkpoint:
        checkpoint_number = self._get_checkpoint_number(checkpoint)
        checkpoint_number_to_delete = checkpoint_number - num_checkpoints_to_keep
        # List the objects of the old checkpoint, i.e. keys under the
        # "<number>_" prefix.
        # NOTE(review): list_objects_v2 returns at most 1000 keys per call.
        # This assumes one checkpoint has fewer files than that.
        response = s3_client.list_objects_v2(
            Bucket=bucket,
            Prefix=self._get_s3_key(str(checkpoint_number_to_delete) + "_"))
        if "Contents" in response:
            num_files = 0
            for obj in response["Contents"]:
                s3_client.delete_object(Bucket=bucket, Key=obj["Key"])
                num_files += 1
            print("Deleted %s model files from S3" % num_files)
    return True