in common/sagemaker_rl/orchestrator/workflow/manager/model_manager.py [0:0]
def _update_model_table_training_states(self):
    """
    Update the training states in the model table.

    Polls the SageMaker training job (DescribeTrainingJob, up to 3 attempts)
    and then updates the training job metadata of the model, including:
        train_state,
        s3_model_output_path,
        training_start_time,
        training_end_time

    The method takes no arguments; it reads ``self.model_record``,
    ``self.model_id``, ``self.sagemaker_client`` and ``self.model_db_client``.

    Returns:
        dict or None: ``self._jsonify()`` when the model is already in a
        terminal state, or when DescribeTrainingJob fails with an error that
        is not a ValidationException (the record is left untouched so a
        later poll can try again). ``None`` after the job status has been
        written to the model table, or after the model has been marked as
        failed.
    """
    if self.model_record.model_in_terminal_state():
        # Model already in one of the final states; just persist the current
        # record, there is nothing to poll.
        self.model_db_client.update_model_record(self._jsonify())
        return self._jsonify()

    # Else, try and fetch updated SageMaker TrainingJob status.
    max_attempts = 3
    sm_job_info = {}
    for attempt in range(max_attempts):
        try:
            sm_job_info = self.sagemaker_client.describe_training_job(
                TrainingJobName=self.model_id
            )
        except Exception as e:
            if "ValidationException" not in str(e):
                # Do not raise exception, most probably throttling. The
                # record is returned unchanged so a later poll retries.
                logger.warning(
                    f"Failed to check SageMaker Training Job state for ModelId {self.model_id}."
                    " This exception will be ignored, and retried."
                )
                logger.debug(e)
                time.sleep(2)
                return self._jsonify()
            if attempt >= max_attempts - 1:
                # Final attempt for DescribeTrainingJob failed with
                # ValidationException: treat the job as never submitted.
                logger.warning(
                    f"Looks like SageMaker Job was not submitted successfully."
                    f" Failing Training Job with ModelId {self.model_id}"
                )
                self.model_record.update_model_as_failed()
                self.model_db_client.update_model_as_failed(self._jsonify())
                return
            # NOTE(review): presumably the job may not be visible yet right
            # after submission, hence the wait before retrying.
            time.sleep(5)
        else:
            # Describe succeeded: stop here so the API is called only once
            # per successful poll.
            break

    train_state = sm_job_info.get("TrainingJobStatus", "Pending")
    training_start_time = sm_job_info.get("TrainingStartTime", None)
    training_end_time = sm_job_info.get("TrainingEndTime", None)
    if training_start_time is not None:
        training_start_time = training_start_time.strftime("%Y-%m-%d %H:%M:%S")
    if training_end_time is not None:
        training_end_time = training_end_time.strftime("%Y-%m-%d %H:%M:%S")

    model_artifacts = sm_job_info.get("ModelArtifacts", None)
    s3_model_output_path = (
        model_artifacts.get("S3ModelArtifacts", None)
        if model_artifacts is not None
        else None
    )

    self.model_record.update_model_job_status(
        training_start_time, training_end_time, train_state, s3_model_output_path
    )
    self.model_db_client.update_model_job_state(self._jsonify())