in robot_ws/src/rl_agent/markov/s3_boto_data_store.py
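
# NOTE: the method below is excerpted from a larger class and relies on module-level
# imports of `os` and `time`, plus the constants CHECKPOINT_METADATA_FILENAME and
# SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND defined elsewhere in this
# file (not shown in this excerpt).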
def load_from_store(self, expected_checkpoint_number=-1):
    try:
        # Local path where the checkpoint metadata file will be written
        filename = os.path.abspath(os.path.join(self.params.checkpoint_dir, CHECKPOINT_METADATA_FILENAME))
        if not os.path.exists(self.params.checkpoint_dir):
            os.makedirs(self.params.checkpoint_dir)
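
        # Poll S3 until the trainer's lock file is gone and a checkpoint at least as
        # new as expected_checkpoint_number is available, then download its files.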
        while True:
            s3_client = self._get_client()
            response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                 Prefix=self._get_s3_key(self.params.lock_file))
            if "Contents" not in response:
                try:
                    # No lock is present, so try fetching the checkpoint metadata file
                    s3_client.download_file(Bucket=self.params.bucket,
                                            Key=self._get_s3_key(CHECKPOINT_METADATA_FILENAME),
                                            Filename=filename)
                except Exception as e:
                    print("Got exception while downloading checkpoint", e)
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue
            else:
                # Lock file present: the trainer is still writing; wait and retry
                time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                continue
            checkpoint = self._get_current_checkpoint()
            if checkpoint:
                checkpoint_number = self._get_checkpoint_number(checkpoint)
                # If we get a checkpoint that is older than the expected checkpoint,
                # we wait for the new checkpoint to arrive.
                if checkpoint_number < expected_checkpoint_number:
                    time.sleep(SLEEP_TIME_WHILE_WAITING_FOR_DATA_FROM_TRAINER_IN_SECOND)
                    continue
                # Found a checkpoint to be downloaded
                response = s3_client.list_objects_v2(Bucket=self.params.bucket,
                                                     Prefix=self._get_s3_key(checkpoint.model_checkpoint_path))
                if "Contents" in response:
                    num_files = 0
                    for obj in response["Contents"]:
                        # Derive the local filename of the checkpoint file from its S3 key
                        filename = os.path.abspath(os.path.join(self.params.checkpoint_dir,
                                                                obj["Key"].replace(self.key_prefix, "")))
                        s3_client.download_file(Bucket=self.params.bucket,
                                                Key=obj["Key"],
                                                Filename=filename)
                        num_files += 1
                    print("Downloaded %s model files from S3" % num_files)
                    return True
    except Exception as e:
        print("Got exception while loading model from S3", e)
        # Re-raise with the original traceback preserved
        raise
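

# A minimal usage sketch (not part of the original file). The class and parameter
# names below (S3BotoDataStore, S3BotoDataStoreParameters, bucket_name,
# checkpoint_dir, s3_folder) are assumptions for illustration; match them to the
# actual definitions in this module before using.
#
#     params = S3BotoDataStoreParameters(bucket_name="my-bucket",        # hypothetical
#                                        checkpoint_dir="./checkpoint",  # hypothetical
#                                        s3_folder="model")              # hypothetical
#     data_store = S3BotoDataStore(params)
#     # Blocks until the trainer has published checkpoint >= 10, then downloads it
#     data_store.load_from_store(expected_checkpoint_number=10)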