in common/sagemaker_rl/orchestrator/workflow/manager/experiment_manager.py [0:0]
def __init__(self,
             config,
             experiment_id,
             training_workflow_metadata=None,
             hosting_workflow_metadata=None,
             joining_workflow_metadata=None,
             evaluation_workflow_metadata=None
             ):
    """Initialize/Reload an experiment entity to manage the workflow

    Args:
        config (dict): Config values for the experiment setup
        experiment_id (str): A unique experiment id for the experiment
        training_workflow_metadata (dict): Metadata for the training workflow.
            Defaults to an empty dict.
        hosting_workflow_metadata (dict): Metadata for the hosting workflow.
            Defaults to an empty dict.
        joining_workflow_metadata (dict): Metadata for the joining workflow.
            Defaults to an empty dict.
        evaluation_workflow_metadata (dict): Metadata for the evaluation workflow.
            Defaults to an empty dict.

    Raises:
        ValueError: If no AWS region is configured or the config lacks an
            ``image`` entry.
        UnhandledWorkflowException: If creating a new experiment record fails
            for a reason other than the record already existing.
    """
    # Avoid mutable default arguments: a shared `{}` default would be
    # mutated across every ExperimentManager instance.
    training_workflow_metadata = training_workflow_metadata if training_workflow_metadata is not None else {}
    hosting_workflow_metadata = hosting_workflow_metadata if hosting_workflow_metadata is not None else {}
    joining_workflow_metadata = joining_workflow_metadata if joining_workflow_metadata is not None else {}
    evaluation_workflow_metadata = evaluation_workflow_metadata if evaluation_workflow_metadata is not None else {}

    self.boto_session = boto3.Session()
    self._region_name = self.boto_session.region_name
    self.account = self.boto_session.client("sts").get_caller_identity()["Account"]
    if self._region_name is None:
        raise ValueError('Must setup AWS configuration with a valid region')

    # unique id common across all experiments in the account
    self.experiment_id = experiment_id

    # load configs
    self.config = config
    image = self.config.get("image", None)
    if image is None:
        # Fail with a clear message instead of an opaque AttributeError
        # on None.replace(...)
        raise ValueError("Config must contain an 'image' entry")
    self.image = image.replace("{AWS_REGION}", self._region_name)
    self.algor_config = self.config.get("algor", {})
    self.local_mode = self.config.get("local_mode", True)
    if self.local_mode:
        self._update_instance_type_for_local_mode()
        self.sagemaker_session = LocalSession()
    else:
        self.sagemaker_session = sagemaker.session.Session(self.boto_session)
    self.soft_deployment = self.config.get("soft_deployment", False)

    # load resource config and init shared resourced if not exists
    self.resource_manager = ResourceManager(self.config.get("resource", {}),
                                            boto_session=self.boto_session)
    self.resource_manager.create_shared_resource_if_not_exist()

    # init clients
    self.exp_db_client = self.resource_manager.exp_db_client
    self.model_db_client = self.resource_manager.model_db_client
    self.join_db_client = self.resource_manager.join_db_client
    self.cw_logger = CloudWatchLogger(
        self.boto_session.client("cloudwatch"),
        self._region_name
    )
    self.sagemaker_client = self.sagemaker_session.sagemaker_client

    # init s3 client for rewards upload
    self.s3_client = self.boto_session.client('s3')

    # create a local JoinJobRecord object.
    self.experiment_record = ExperimentRecord(
        experiment_id,
        training_workflow_metadata,
        hosting_workflow_metadata,
        joining_workflow_metadata,
        evaluation_workflow_metadata
    )

    self.next_model_to_train = None
    self.next_join_job = None
    self.next_model_to_evaluate = None

    # Try to save new ExperimentRecord to ExperimentDb. If it throws
    # RecordAlreadyExistsException, re-read the ExperimentRecord from ExperimentDb,
    # and use it as initial state
    try:
        self.exp_db_client.create_new_experiment_record(
            self.experiment_record.to_ddb_record()
        )
    except RecordAlreadyExistsException:
        # logger.warn is deprecated; use logger.warning
        logger.warning(f"Experiment with name {self.experiment_id} already exists. "
                       "Reusing current state from ExperimentDb.")
        experiment_record = self.exp_db_client.get_experiment_record(
            experiment_id
        )
        self.experiment_record = ExperimentRecord.load_from_ddb_record(experiment_record)
    except Exception as e:
        logger.error("Unhandled Exception! " + str(e))
        # Chain the original exception so the root cause is preserved
        # in the traceback.
        raise UnhandledWorkflowException("Something went wrong while creating a new experiment") from e

    # Dashboard creation is best-effort: log and continue on failure.
    try:
        self.cw_logger.create_cloudwatch_dashboard_from_experiment_id(
            self.experiment_id
        )
    except Exception as e:
        logger.error("Unable to create CloudWatch Dashboard." + str(e))
        logger.error("To see metrics on CloudWatch, run bandit_experiment."
                     "cw_logger.create_cloudwatch_dashboard_from_experiment_id function again.")

    # start a daemon thread to sync ExperimentDb states to local states
    # the daemon thread will keep running till the session ends
    self.sync_thread = ExperimentManagerSyncThread(experiment_manager=self)
    # Run the thread in SageMaker mode only
    if not self.local_mode:
        # Thread.setDaemon() is deprecated; set the attribute directly.
        self.sync_thread.daemon = True
        self.sync_thread.start()