in source/RLlibEnv/training/common/sagemaker_rl/orchestrator/workflow/manager/join_manager.py [0:0]
def __init__(
        self,
        join_db_client: JoinDbClient,
        experiment_id,
        join_job_id,
        current_state=None,
        input_obs_data_s3_path=None,
        obs_start_time=None,
        obs_end_time=None,
        input_reward_data_s3_path=None,
        output_joined_train_data_s3_path=None,
        output_joined_eval_data_s3_path=None,
        join_query_ids=None,
        boto_session=None):
    """Initialize a joining job entity in the current experiment

    Args:
        join_db_client (JoinDbClient): A DynamoDB client
            to query the joining job table. The 'JoinJob' entity use this client
            to read/update the job state.
        experiment_id (str): A unique id for the experiment. The created/loaded
            joining job will be associated with the experiment.
        join_job_id (str): A unique id for the join job. The join job table uses
            join_job_id to manage associated job metadata.
        current_state (str): Current state of the joining job
        input_obs_data_s3_path (str): Input S3 data path for observation data
        obs_start_time (datetime): Datetime object to specify starting time of the
            observation data
        obs_end_time (datetime): Datetime object to specify ending time of the
            observation data
        input_reward_data_s3_path (str): S3 data path for rewards data
        output_joined_train_data_s3_path (str): Output S3 data path for training data split
        output_joined_eval_data_s3_path (str): Output S3 data path for evaluation data split
        join_query_ids (list): Athena join query ids for the joining requests.
            Defaults to an empty list when not provided.
        boto_session (boto3.session.Session): A session stores configuration
            state and allows you to create service clients and resources.

    Return:
        orchestrator.join_manager.JoinManager: A ``JoinJob`` object associated
            with the given experiment.
    """
    # NOTE: the default was previously a mutable `[]`, which is shared
    # across all calls; use a None sentinel and build a fresh list instead.
    if join_query_ids is None:
        join_query_ids = []

    self.join_db_client = join_db_client
    self.experiment_id = experiment_id
    self.join_job_id = join_job_id

    if boto_session is None:
        boto_session = boto3.Session()
    self.boto_session = boto_session

    # formatted athena table name
    self.obs_table_partitioned = self._formatted_table_name(f"obs-{experiment_id}-partitioned")
    self.obs_table_non_partitioned = self._formatted_table_name(f"obs-{experiment_id}")
    self.rewards_table = self._formatted_table_name(f"rewards-{experiment_id}")

    self.query_s3_output_bucket = self._create_athena_s3_bucket_if_not_exist()
    self.athena_client = self.boto_session.client("athena")

    # create a local JoinJobRecord object.
    self.join_job_record = JoinJobRecord(
        experiment_id,
        join_job_id,
        current_state,
        input_obs_data_s3_path,
        obs_start_time,
        obs_end_time,
        input_reward_data_s3_path,
        output_joined_train_data_s3_path,
        output_joined_eval_data_s3_path,
        join_query_ids
    )

    # create obs partitioned/non-partitioned table if not exists;
    # the sentinel value signals the data lives locally and no Athena table is needed
    if input_obs_data_s3_path and input_obs_data_s3_path != "local-join-does-not-apply":
        self._create_obs_table_if_not_exist()
    # create reward table if not exists
    if input_reward_data_s3_path and input_reward_data_s3_path != "local-join-does-not-apply":
        self._create_rewards_table_if_not_exist()
    # add partitions if input_obs_time_window is not None
    if obs_start_time and obs_end_time:
        self._add_time_partitions(obs_start_time, obs_end_time)

    # try to save this record file. if it throws RecordAlreadyExistsException
    # reload the record from JoinJobDb, and recreate
    try:
        self.join_db_client.create_new_join_job_record(
            self.join_job_record.to_ddb_record()
        )
    except RecordAlreadyExistsException:
        logger.debug("Join job already exists. Reloading from join job record.")
        join_job_record = self.join_db_client.get_join_job_record(
            experiment_id,
            join_job_id
        )
        self.join_job_record = JoinJobRecord.load_from_ddb_record(join_job_record)
    except Exception as e:
        logger.error("Unhandled Exception! " + str(e))
        # chain the cause so the original traceback is preserved for debugging
        raise UnhandledWorkflowException("Something went wrong while creating a new join job") from e