in src/markov/boto/s3/files/checkpoint_files/tensorflow_model.py [0:0]
def get(self, coach_checkpoint_state_file):
    '''Download the tensorflow model named in the rl coach checkpoint state file.

    The tensorflow model s3 directory is listed looking for an object whose
    file name starts with the checkpoint name read from the state file. When
    no such object is found, the checkpoint with the highest leading number is
    selected instead and the local rl coach checkpoint state file is
    overwritten so that it refers to that checkpoint. When the directory holds
    no numbered checkpoint file at all, log_and_exit is called with
    SIMAPP_S3_DATA_STORE_EXCEPTION / SIMAPP_EVENT_ERROR_CODE_400.

    Every listed object whose file name starts with the selected checkpoint
    name, except those with a '.pb' extension, is then downloaded into the
    local directory.

    Args:
        coach_checkpoint_state_file (CheckpointStateFile): CheckpointStateFile instance
    '''
    wanted_name = str(coach_checkpoint_state_file.read())
    found = False
    newest_number = -1
    newest_name = None

    # First pass: search the listing for the wanted checkpoint while keeping
    # track of the highest-numbered checkpoint as a fallback.
    for page in self._s3_client.paginate(bucket=self._bucket, prefix=self._s3_key_dir):
        if "Contents" in page:
            # e.g. for key (dir)/487_Step-2477372.ckpt.data-00000-of-00001
            #   file name:     487_Step-2477372.ckpt.data-00000-of-00001
            #   number prefix: 487
            for entry in page['Contents']:
                listed_name = os.path.basename(entry['Key'])
                if listed_name.startswith(wanted_name):
                    # the checkpoint from .coach_checkpoint exists; stop searching
                    found = True
                    break
                number_prefix = listed_name.split("_")[0]
                if not utils.is_int_repr(number_prefix):
                    # not a checkpoint file, nothing to record
                    continue
                listed_number = int(number_prefix)
                if listed_number > newest_number:
                    newest_number = listed_number
                    # drop only the final extension to obtain the checkpoint name
                    newest_name = listed_name.rsplit('.', 1)[0]
        if found:
            # no need to request further pages
            break

    # Fallback: point the local .coach_checkpoint file at the newest
    # checkpoint, or report that the directory has no checkpoint files.
    if not found:
        if not newest_name:
            log_and_exit("No checkpoint files",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)
        else:
            replacement = SingleCheckpoint(num=newest_number, name=newest_name)
            coach_checkpoint_state_file.write(replacement)
            LOG.info("%s not in s3 bucket, downloading %s checkpoints", wanted_name, newest_name)
            wanted_name = newest_name

    # Second pass: download every file belonging to the selected checkpoint.
    for page in self._s3_client.paginate(bucket=self._bucket, prefix=self._s3_key_dir):
        if "Contents" not in page:
            continue
        for entry in page['Contents']:
            s3_key = entry["Key"]
            file_name = os.path.basename(s3_key)
            local_path = os.path.normpath(os.path.join(self._local_dir, file_name))
            extension = os.path.splitext(s3_key)[1]
            # '.pb' files are skipped even when their name matches
            if extension == '.pb' or not file_name.startswith(wanted_name):
                continue
            self._download(s3_key=s3_key, local_path=local_path)