in src/markov/boto/s3/files/checkpoint_files/tensorflow_model.py [0:0]
def copy_best_frozen_graph_to_sm_output_dir(self, best_checkpoint_number, last_checkpoint_number,
                                            source_dir, dest_dir):
    """Copy the frozen model for the current best checkpoint from source directory to the destination directory.

    The selected frozen graph is always written to ``<dest_dir>/model.pb``.

    * Fallback: when ``best_checkpoint_number`` is -1 or ``dest_dir`` holds no
      ``.pb`` file yet, the first ``.pb`` file of ``source_dir`` (in sorted
      name order) is copied and nothing is deleted.
    * Otherwise: every ``.pb`` file in ``dest_dir`` is removed,
      ``model_<best_checkpoint_number>.pb`` is copied from ``source_dir``, and
      the frozen graphs in ``source_dir`` whose iteration number is lower than
      ``last_checkpoint_number`` are deleted, except the best and the last one.

    If ``source_dir`` holds no ``.pb`` file, or a file disappears while it is
    being copied/removed (FileNotFoundError), the failure is reported through
    ``log_and_exit``.

    Args:
        best_checkpoint_number (int): Iteration number of the best checkpoint;
            -1 selects the fallback behaviour described above
        last_checkpoint_number (int): Iteration number of the last checkpoint;
            older frozen graphs in the source directory are pruned
        source_dir (str): Source directory where the frozen models are present
        dest_dir (str): Sagemaker output directory where we store the frozen models for best checkpoint
    """
    pb_suffix = ".pb"
    dest_dir_pb_files = [filename for filename in os.listdir(dest_dir)
                         if os.path.isfile(os.path.join(dest_dir, filename))
                         and filename.endswith(pb_suffix)]
    # os.listdir() returns entries in arbitrary order; sort so that the
    # "first one" picked by the fallback branch is the same on every run.
    source_dir_pb_files = sorted(filename for filename in os.listdir(source_dir)
                                 if os.path.isfile(os.path.join(source_dir, filename))
                                 and filename.endswith(pb_suffix))
    LOG.info("Best checkpoint number: {}, Last checkpoint number: {}".
             format(best_checkpoint_number, last_checkpoint_number))
    best_model_name = 'model_{}.pb'.format(best_checkpoint_number)
    last_model_name = 'model_{}.pb'.format(last_checkpoint_number)
    dest_model_path = os.path.join(dest_dir, "model.pb")
    if not source_dir_pb_files:
        # NOTE(review): the code below indexes source_dir_pb_files[0], so this
        # relies on log_and_exit not returning - confirm against its definition.
        log_and_exit("Could not find any frozen model file in the local directory",
                     SIMAPP_S3_DATA_STORE_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    try:
        # Could not find the deepracer_checkpoints.json file or there are no model.pb files in destination
        if best_checkpoint_number == -1 or not dest_dir_pb_files:
            if len(source_dir_pb_files) > 1:
                LOG.info("More than one model.pb found in the source directory. Choosing the "
                         "first one to copy to destination: {}".format(source_dir_pb_files[0]))
            # Copy the frozen model present in the source directory
            fallback_model_path = os.path.join(source_dir, source_dir_pb_files[0])
            LOG.info("Copying the frozen checkpoint from {} to {}.".format(
                fallback_model_path, dest_model_path))
            shutil.copy(fallback_model_path, dest_model_path)
        else:
            # Delete the current .pb files in the destination directory
            for filename in dest_dir_pb_files:
                os.remove(os.path.join(dest_dir, filename))
            # Copy the frozen model for the current best checkpoint to the destination directory
            best_model_path = os.path.join(source_dir, best_model_name)
            LOG.info("Copying the frozen checkpoint from {} to {}.".format(
                best_model_path, dest_model_path))
            shutil.copy(best_model_path, dest_model_path)
            # Loop through the current list of frozen models in source directory and
            # delete the iterations lower than last_checkpoint_number except the
            # best and the last model
            for filename in source_dir_pb_files:
                if filename in (best_model_name, last_model_name):
                    continue
                # Expected layout is "<prefix>_<iteration>.pb". Anything else is
                # reported and left on disk instead of raising IndexError /
                # ValueError, which the FileNotFoundError handler would not catch.
                name_fields = filename[:-len(pb_suffix)].split("_")
                file_iteration = None
                if len(name_fields) == 2:
                    try:
                        file_iteration = int(name_fields[1])
                    except ValueError:
                        # Non-numeric iteration field: handled as malformed below
                        file_iteration = None
                if file_iteration is None:
                    LOG.error("Frozen model name not in the right format in the source directory: {}, {}".
                              format(filename, source_dir))
                    continue
                if file_iteration < last_checkpoint_number:
                    os.remove(os.path.join(source_dir, filename))
    except FileNotFoundError as err:
        log_and_exit("No such file or directory: {}".format(err),
                     SIMAPP_S3_DATA_STORE_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_400)