in common/sagemaker_rl/orchestrator/workflow/manager/model_manager.py [0:0]
def __init__(
    self,
    model_db_client: ModelDbClient,
    experiment_id,
    model_id,
    image=None,
    role=None,
    instance_config=None,
    boto_session=None,
    algor_config=None,
    train_state=None,
    evaluation_job_name=None,
    eval_state=None,
    eval_scores=None,
    input_model_id=None,
    rl_estimator=None,
    input_data_s3_prefix=None,
    manifest_file_path=None,
    eval_data_s3_path=None,
    s3_model_output_path=None,
    training_start_time=None,
    training_end_time=None):
    """Initialize a model entity in the current experiment.

    A local ``ModelRecord`` is built from the given arguments and an attempt
    is made to persist it through ``model_db_client``. If a record with the
    same (experiment_id, model_id) already exists, the stored record is
    loaded instead and the record-related arguments passed here are ignored.

    Args:
        model_db_client (ModelDbClient): A DynamoDB client to query the
            model table. The 'Model' entity uses this client to read/update
            the model state.
        experiment_id (str): A unique id for the experiment. The
            created/loaded model will be associated with the given
            experiment.
        model_id (str): A unique id for the model. The model table uses
            the model id to manage associated model metadata.
        image (str): The container image to use for training/evaluation.
        role (str): An AWS IAM role (either name or full ARN). The Amazon
            SageMaker training jobs will use this role to access AWS
            resources.
        instance_config (dict): A dictionary that specifies the resource
            configuration for the model training/evaluation job. Keys read
            here: "instance_type" (default "local") and "instance_count"
            (default 1). Defaults to an empty dict.
        boto_session (boto3.session.Session): A session that stores
            configuration state and allows you to create service clients
            and resources. A new ``boto3.Session()`` is created if omitted.
        algor_config (dict): A dictionary that specifies the algorithm type
            and hyper parameters of the training/evaluation job. The
            "algorithms_parameters" key (default ``{}``) is read here.
            Defaults to an empty dict.
        train_state (str): State of the model training job.
        evaluation_job_name (str): Job name of the latest evaluation job
            for this model.
        eval_state (str): State of the model evaluation job.
        eval_scores (dict): Evaluation scores recorded for this model.
            Defaults to a fresh empty dict.
        input_model_id (str): A unique model id to specify which model to
            use as a pre-trained model for the model training job.
        rl_estimator (sagemaker.rl.estimator.RLEstimator): A SageMaker
            RLEstimator entity that handles Reinforcement Learning (RL)
            execution within a SageMaker Training Job. Currently not stored
            or used by this constructor.
        input_data_s3_prefix (str): Input data path for the data source of
            the model training job.
        manifest_file_path (str): Manifest file path stored on the model
            record.
        eval_data_s3_path (str): Evaluation data path stored on the model
            record.
        s3_model_output_path (str): Output data path of the model artifact
            for the model training job.
        training_start_time (str): Starting timestamp of the model training
            job.
        training_end_time (str): Finished timestamp of the model training
            job.

    Raises:
        UnhandledWorkflowException: If creating the model record fails for
            any reason other than the record already existing.
    """
    # Mutable defaults would be shared across every ModelManager instance;
    # build fresh dicts per call instead.
    if instance_config is None:
        instance_config = {}
    if algor_config is None:
        algor_config = {}
    if eval_scores is None:
        eval_scores = {}

    self.model_db_client = model_db_client
    self.experiment_id = experiment_id
    self.model_id = model_id
    # Currently we are not storing image/role and other model params in ModelDb
    self.image = image
    self.role = role
    self.instance_config = instance_config
    self.algor_config = algor_config

    # load configs
    self.instance_type = self.instance_config.get("instance_type", "local")
    self.instance_count = self.instance_config.get("instance_count", 1)
    self.algor_params = self.algor_config.get("algorithms_parameters", {})

    # create a local ModelRecord object.
    self.model_record = ModelRecord(
        experiment_id,
        model_id,
        train_state,
        evaluation_job_name,
        eval_state,
        eval_scores,
        input_model_id,
        input_data_s3_prefix,
        manifest_file_path,
        eval_data_s3_path,
        s3_model_output_path,
        training_start_time,
        training_end_time
    )

    # Try to persist this record. If it already exists in ModelDb, reload
    # the stored record and use it in place of the locally built one.
    try:
        self.model_db_client.create_new_model_record(
            self.model_record.to_ddb_record()
        )
    except RecordAlreadyExistsException:
        logger.debug("Model already exists. Reloading from model record.")
        model_record = self.model_db_client.get_model_record(
            experiment_id,
            model_id
        )
        self.model_record = ModelRecord.load_from_ddb_record(model_record)
    except Exception as e:
        logger.error("Unhandled Exception! " + str(e))
        # Chain the original error so its traceback is not lost.
        raise UnhandledWorkflowException(
            "Something went wrong while creating a new model"
        ) from e

    if boto_session is None:
        boto_session = boto3.Session()
    self.boto_session = boto_session

    # "local" instance type runs jobs through SageMaker local mode.
    if self.instance_type == 'local':
        self.sagemaker_session = LocalSession()
    else:
        self.sagemaker_session = sagemaker.session.Session(self.boto_session)
    self.sagemaker_client = self.sagemaker_session.sagemaker_client