in src/markov/boto/s3/files/checkpoint.py [0:0]
def __init__(self, bucket, s3_prefix, region_name="us-east-1",
             agent_name='agent', checkpoint_dir="./checkpoint",
             max_retry_attempts=5, backoff_time_sec=1.0,
             output_head_format='main_level/{}/main/online/network_1/ppo_head_0/policy',
             log_and_cont: bool = False):
    '''This class is a placeholder for RLCoachCheckpoint, DeepracerCheckpointJson,
    RlCoachSyncFile, TensorflowModel to handle all checkpoint related logic

    Args:
        bucket (str): S3 bucket string.
        s3_prefix (str): S3 prefix string.
        region_name (str, optional): S3 region name.
            Defaults to 'us-east-1'.
        agent_name (str, optional): Agent name.
            Defaults to 'agent'.
        checkpoint_dir (str, optional): Local file directory.
            Defaults to './checkpoint'.
        max_retry_attempts (int, optional): Maximum number of retry attempts for S3 download/upload.
            Defaults to 5.
        backoff_time_sec (float, optional): Backoff second between each retry.
            Defaults to 1.0.
        output_head_format (str): output head format for the specific algorithm and action space
            which will be used to store the frozen graph
        log_and_cont (bool, optional): Log the error and continue with the flow.
            Defaults to False.
    '''
    # Both bucket and prefix are mandatory; bail out through the simapp
    # error path (log_and_exit terminates the worker) if either is missing.
    if not bucket or not s3_prefix:
        log_and_exit("checkpoint S3 prefix or bucket not available for S3. \
                     bucket: {}, prefix {}"
                     .format(bucket, s3_prefix),
                     SIMAPP_SIMULATION_WORKER_EXCEPTION,
                     SIMAPP_EVENT_ERROR_CODE_500)
    self._agent_name = agent_name
    self._s3_dir = os.path.normpath(os.path.join(s3_prefix,
                                                 CHECKPOINT_POSTFIX_DIR))
    # All per-agent checkpoint files share the same local directory
    # (<checkpoint_dir>/<agent_name>); compute it once instead of joining
    # the same paths for every helper below.
    agent_local_dir = os.path.join(checkpoint_dir, agent_name)
    # rl coach checkpoint
    self._rl_coach_checkpoint = RLCoachCheckpoint(bucket=bucket,
                                                  s3_prefix=s3_prefix,
                                                  region_name=region_name,
                                                  local_dir=agent_local_dir,
                                                  max_retry_attempts=max_retry_attempts,
                                                  backoff_time_sec=backoff_time_sec,
                                                  log_and_cont=log_and_cont)
    # deepracer checkpoint json
    # do not retry on deepracer checkpoint because initially
    # it may not exist.
    self._deepracer_checkpoint_json = \
        DeepracerCheckpointJson(bucket=bucket,
                                s3_prefix=s3_prefix,
                                region_name=region_name,
                                local_dir=agent_local_dir,
                                max_retry_attempts=0,
                                backoff_time_sec=backoff_time_sec,
                                log_and_cont=log_and_cont)
    # rl coach .finished
    self._syncfile_finished = RlCoachSyncFile(syncfile_type=SyncFiles.FINISHED.value,
                                              bucket=bucket,
                                              s3_prefix=s3_prefix,
                                              region_name=region_name,
                                              local_dir=agent_local_dir,
                                              max_retry_attempts=max_retry_attempts,
                                              backoff_time_sec=backoff_time_sec)
    # rl coach .lock: global lock for all agents, so it lives at the
    # checkpoint directory root rather than in the per-agent directory.
    self._syncfile_lock = RlCoachSyncFile(syncfile_type=SyncFiles.LOCKFILE.value,
                                          bucket=bucket,
                                          s3_prefix=s3_prefix,
                                          region_name=region_name,
                                          local_dir=checkpoint_dir,
                                          max_retry_attempts=max_retry_attempts,
                                          backoff_time_sec=backoff_time_sec)
    # rl coach .ready
    self._syncfile_ready = RlCoachSyncFile(syncfile_type=SyncFiles.TRAINER_READY.value,
                                           bucket=bucket,
                                           s3_prefix=s3_prefix,
                                           region_name=region_name,
                                           local_dir=agent_local_dir,
                                           max_retry_attempts=max_retry_attempts,
                                           backoff_time_sec=backoff_time_sec)
    # tensorflow .ckpt files
    self._tensorflow_model = TensorflowModel(bucket=bucket,
                                             s3_prefix=s3_prefix,
                                             region_name=region_name,
                                             local_dir=agent_local_dir,
                                             max_retry_attempts=max_retry_attempts,
                                             backoff_time_sec=backoff_time_sec,
                                             output_head_format=output_head_format)