in archived/rl_gamerserver_ray/common/sagemaker_rl/orchestrator/workflow/manager/experiment_manager.py [0:0]
def join(self, rewards_s3_path, obs_time_window=None, ratio=0.8, wait=True):
    """Start a joining job given rewards data path and observation
    data time window.

    Args:
        rewards_s3_path (str): S3 data path containing the rewards data
        obs_time_window (int): Define a time window of past X hours to
            select observation data; ``None`` joins against all
            observation data
        ratio (float): Split ratio used to split training data
            and evaluation data
        wait (bool): Whether to wait until the joining job finish

    Raises:
        UnhandledWorkflowException: If the experiment table never reports
            a successful join after the retry budget is exhausted.
        WorkflowJoiningJobException: If the joining job ends in FAILED or
            CANCELLED state.
    """
    # Sync experiment state if required
    self._sync_experiment_state_with_ddb()

    if obs_time_window is None:
        logger.warning(
            f"Start a join job to join reward data "
            f"under '{rewards_s3_path}' with all the observation data"
        )
        obs_end_time = None
        obs_start_time = None
    else:
        logger.info(
            f"Start a join job to join reward data "
            f"under '{rewards_s3_path}' with observation "
            f"data in the past {obs_time_window} hours"
        )
        # NOTE(review): naive UTC timestamps, matching the original code;
        # downstream consumers presumably expect naive datetimes — confirm
        # before switching to timezone-aware datetime.now(timezone.utc).
        obs_end_time = datetime.utcnow()
        obs_start_time = obs_end_time - timedelta(hours=obs_time_window)

    # update next_join_job_id and joining state
    next_join_job_id = JoinManager.name_next_join_job(experiment_id=self.experiment_id)
    self.exp_db_client.update_experiment_next_join_job_id(self.experiment_id, next_join_job_id)
    self.exp_db_client.update_experiment_joining_state(self.experiment_id, JoiningState.PENDING)

    # Observation data is read from the Firehose bucket under this experiment's prefix.
    input_obs_data_s3_path = (
        f"s3://{self.resource_manager.firehose_bucket}/{self.experiment_id}/inference_data"
    )

    # init joining job, update join table
    logger.info("Creating resource for joining job...")
    try:
        self.next_join_job = JoinManager(
            join_db_client=self.join_db_client,
            experiment_id=self.experiment_id,
            join_job_id=next_join_job_id,
            input_obs_data_s3_path=input_obs_data_s3_path,
            obs_start_time=obs_start_time,
            obs_end_time=obs_end_time,
            input_reward_data_s3_path=rewards_s3_path,
            boto_session=self.boto_session,
        )
        logger.info("Started joining job...")
        self.next_join_job.start_join(ratio=ratio, wait=wait)
    except Exception:
        # Best-effort: log the full traceback (logger.error(e) used to drop
        # it) and fall through — the polling loop below surfaces the failure
        # via the experiment table's joining state.
        logger.exception("Failed to create/start the joining job")

    # wait until exp ddb table updated
    if self.local_mode or wait:

        def _join_succeeded():
            # Success requires all three markers: state flipped to SUCCEEDED,
            # our job id recorded as the last joined job, and the pending
            # next-join pointer cleared.
            return (
                self.experiment_record._joining_state == JoiningState.SUCCEEDED
                and self.experiment_record._last_joined_job_id == next_join_job_id
                and self.experiment_record._next_join_job_id is None
            )

        num_retries = 0
        max_retries = 5
        while not _join_succeeded():
            # Sync experiment state if required
            self._sync_experiment_state_with_ddb()
            logger.debug("Waiting for experiment table joining status to be updated...")
            # Exponential backoff: 2s, 4s, 8s, 16s, 32s.
            time.sleep(2 * (2 ** num_retries))
            num_retries += 1
            if not _join_succeeded():
                # Bug fix: only raise after exhausting retries if the join
                # has NOT succeeded — previously the timeout exception fired
                # even when the final poll observed a successful join.
                if num_retries >= max_retries:
                    raise UnhandledWorkflowException(
                        f"Joining job '{self.experiment_record._next_join_job_id}' "
                        f"was in state of '{self.experiment_record._joining_state}'. Failed to sync table states."
                    )
                # Fail fast on a terminal joining state instead of burning
                # the remaining backoff budget.
                if self.experiment_record._joining_state in (
                    JoiningState.FAILED,
                    JoiningState.CANCELLED,
                ):
                    raise WorkflowJoiningJobException(
                        f"Joining job '{self.experiment_record._next_join_job_id}' "
                        f"ended with state '{self.experiment_record._joining_state}'. Please check Athena queries logs "
                        "for more information."
                    )