def get()

in src/markov/boto/s3/files/checkpoint_files/tensorflow_model.py [0:0]


    def get(self, coach_checkpoint_state_file):
        '''Download the tensorflow model named in the rl coach checkpoint state file.

        The tensorflow model s3 directory is listed once. If any file in it starts
        with the checkpoint name read from the state file, every matching non-.pb
        file is downloaded into the local directory.

        If the requested checkpoint is not in the bucket (or the name read from the
        state file is empty), the checkpoint with the highest leading number is used
        instead and the local rl coach checkpoint state file is overwritten to point
        at it. If no checkpoint file exists at all, log_and_exit is called.

        Args:
            coach_checkpoint_state_file (CheckpointStateFile): CheckpointStateFile instance
        '''
        checkpoint_name = str(coach_checkpoint_state_file.read())

        # List everything in the tensorflow model s3 dir exactly once and reuse the
        # keys for both the lookup and the download, so the two steps work on the
        # same listing and the bucket is not paginated twice.
        s3_keys = []
        for page in self._s3_client.paginate(bucket=self._bucket, prefix=self._s3_key_dir):
            if "Contents" in page:
                s3_keys.extend(obj['Key'] for obj in page['Contents'])
        # for example if key is (dir)/487_Step-2477372.ckpt.data-00000-of-00001
        # file name: 487_Step-2477372.ckpt.data-00000-of-00001
        file_names = [os.path.split(s3_key)[1] for s3_key in s3_keys]

        # An empty name is a prefix of every file name, so it must never count as
        # "found"; otherwise every non-.pb object would be downloaded.
        has_checkpoint = bool(checkpoint_name) and any(
            file_name.startswith(checkpoint_name) for file_name in file_names)

        if not has_checkpoint:
            # Fall back to the checkpoint with the highest leading number,
            # e.g. 487 for 487_Step-2477372.ckpt.data-00000-of-00001
            last_checkpoint_number = -1
            last_checkpoint_name = None
            for file_name in file_names:
                number_prefix = file_name.split("_")[0]
                # file names not starting with a number are not ckpt files
                if not utils.is_int_repr(number_prefix):
                    continue
                curr_checkpoint_number = int(number_prefix)
                # strict '>' keeps the first file seen for a given number
                if curr_checkpoint_number > last_checkpoint_number:
                    last_checkpoint_number = curr_checkpoint_number
                    # drop the last extension:
                    # 487_Step-2477372.ckpt.data-00000-of-00001 -> 487_Step-2477372.ckpt
                    last_checkpoint_name = file_name.rsplit('.', 1)[0]

            if not last_checkpoint_name:
                log_and_exit("No checkpoint files",
                             SIMAPP_S3_DATA_STORE_EXCEPTION,
                             SIMAPP_EVENT_ERROR_CODE_400)
                # NOTE(review): log_and_exit presumably terminates the process;
                # return explicitly so nothing is downloaded if it ever does not.
                return

            # overwrite local .coach_checkpoint file to contain the last checkpoint
            coach_checkpoint_state_file.write(SingleCheckpoint(
                num=last_checkpoint_number,
                name=last_checkpoint_name))
            LOG.info("%s not in s3 bucket, downloading %s checkpoints", checkpoint_name, last_checkpoint_name)
            checkpoint_name = last_checkpoint_name

        # download the desired checkpoint files (everything but .pb)
        for s3_key, file_name in zip(s3_keys, file_names):
            _, file_extension = os.path.splitext(s3_key)
            if file_extension != '.pb' and file_name.startswith(checkpoint_name):
                local_path = os.path.normpath(os.path.join(self._local_dir,
                                                           file_name))
                self._download(s3_key=s3_key, local_path=local_path)