def copy_best_frozen_graph_to_sm_output_dir()

in src/markov/boto/s3/files/checkpoint_files/tensorflow_model.py [0:0]


    def copy_best_frozen_graph_to_sm_output_dir(self, best_checkpoint_number, last_checkpoint_number,
                                                source_dir, dest_dir):
        """Copy the frozen model for the current best checkpoint from source directory to the destination directory.

        The selected frozen model is always written into ``dest_dir`` as ``model.pb``.

        * If ``best_checkpoint_number`` is -1 or ``dest_dir`` holds no ``.pb`` files, the first ``.pb`` file
          returned by ``os.listdir(source_dir)`` is copied. The listing order is arbitrary.
        * Otherwise ``model_<best_checkpoint_number>.pb`` is copied. Every other ``.pb`` file in ``dest_dir``
          is then removed. Frozen models in ``source_dir`` whose iteration is lower than
          ``last_checkpoint_number`` are deleted, except the best and last ones.

        Args:
            best_checkpoint_number (int): Checkpoint number of the best model, or -1 when the
                deepracer_checkpoints.json file could not be found
            last_checkpoint_number (int): Checkpoint number of the most recent model
            source_dir (str): Source directory where the frozen models are present
            dest_dir (str): Sagemaker output directory where we store the frozen models for best checkpoint

        Errors are not raised to the caller. They are reported through ``log_and_exit`` in two cases:
        ``source_dir`` contains no ``.pb`` file (500), or a required file is missing (400).
        """
        dest_dir_pb_files = [filename for filename in os.listdir(dest_dir)
                             if os.path.isfile(os.path.join(dest_dir, filename)) and filename.endswith(".pb")]
        source_dir_pb_files = [filename for filename in os.listdir(source_dir)
                               if os.path.isfile(os.path.join(source_dir, filename)) and filename.endswith(".pb")]

        LOG.info("Best checkpoint number: {}, Last checkpoint number: {}".
                 format(best_checkpoint_number, last_checkpoint_number))
        best_model_name = 'model_{}.pb'.format(best_checkpoint_number)
        last_model_name = 'model_{}.pb'.format(last_checkpoint_number)
        dest_model_name = "model.pb"
        dest_model_path = os.path.join(dest_dir, dest_model_name)
        if not source_dir_pb_files:
            log_and_exit("Could not find any frozen model file in the local directory",
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_500)
        try:
            # Could not find the deepracer_checkpoints.json file or there are no model.pb files in destination
            if best_checkpoint_number == -1 or not dest_dir_pb_files:
                # NOTE: os.listdir returns entries in arbitrary order, so with several candidates
                # the "first" one is not a principled choice.
                if len(source_dir_pb_files) > 1:
                    LOG.info("More than one model.pb found in the source directory. Choosing the "
                             "first one to copy to destination: {}".format(source_dir_pb_files[0]))
                # copy the frozen model present in the source directory
                source_model_path = os.path.join(source_dir, source_dir_pb_files[0])
                LOG.info("Copying the frozen checkpoint from {} to {}.".format(source_model_path, dest_model_path))
                shutil.copy(source_model_path, dest_model_path)
            else:
                # Copy the best model BEFORE cleaning the destination. If the best model is missing
                # from source_dir, the FileNotFoundError fires while the previously exported model.pb
                # is still in place. shutil.copy overwrites an existing model.pb.
                best_model_path = os.path.join(source_dir, best_model_name)
                LOG.info("Copying the frozen checkpoint from {} to {}.".format(best_model_path, dest_model_path))
                shutil.copy(best_model_path, dest_model_path)

                # Delete the remaining stale .pb files in the destination directory (model.pb was just refreshed)
                for filename in dest_dir_pb_files:
                    if filename != dest_model_name:
                        os.remove(os.path.join(dest_dir, filename))

                # Loop through the current list of frozen models in source directory and
                # delete the iterations lower than last_checkpoint_number except best and last models
                for filename in source_dir_pb_files:
                    if filename in (best_model_name, last_model_name):
                        continue
                    try:
                        # Expected name format: <prefix>_<iteration>.pb (e.g. model_12.pb).
                        # IndexError: no underscore; ValueError: iteration part is not an integer.
                        file_iteration = int(filename.split("_")[1].split(".pb")[0])
                    except (IndexError, ValueError):
                        LOG.error("Frozen model name not in the right format in the source directory: {}, {}".
                                  format(filename, source_dir))
                        continue
                    if file_iteration < last_checkpoint_number:
                        os.remove(os.path.join(source_dir, filename))
        except FileNotFoundError as err:
            log_and_exit("No such file or directory: {}".format(err),
                         SIMAPP_S3_DATA_STORE_EXCEPTION,
                         SIMAPP_EVENT_ERROR_CODE_400)