in common/sagemaker_rl/orchestrator/workflow/manager/model_manager.py [0:0]
def _update_model_table_evaluation_states(self):
    """Update the evaluation states in the model table.

    This method polls the SageMaker evaluation job and updates the
    evaluation metadata of the current model record
    (self.model_record), including:
        eval_state,
        eval_scores
    """
    if self.model_record.eval_in_terminal_state():
        self.model_db_client.update_model_record(self._jsonify())
        return self._jsonify()

    # Try and fetch updated SageMaker Training Job Status
    sm_eval_job_info = {}
    for i in range(3):
        try:
            sm_eval_job_info = self.sagemaker_client.describe_training_job(
                TrainingJobName=self.model_record._evaluation_job_name
            )
        except Exception as e:
            if "ValidationException" in str(e):
                if i >= 2:
                    # Third DescribeTrainingJob attempt also failed with a
                    # ValidationException; the job was likely never submitted.
                    logger.warning(
                        "Looks like the SageMaker job was not submitted successfully."
                        f" Failing EvaluationJob {self.model_record._evaluation_job_name}"
                    )
                    self.model_record.update_eval_job_as_failed()
                    self.model_db_client.update_model_eval_as_failed(self._jsonify())
                    return
                else:
                    time.sleep(5)
                    continue
            else:
                # Do not raise; the failure is most likely throttling. The
                # status check will be retried on the next poll.
                logger.warning(
                    "Failed to check SageMaker Training Job state for EvaluationJob: "
                    f" {self.model_record._evaluation_job_name}. This exception will be"
                    " ignored and retried."
                )
                time.sleep(2)
                return self._jsonify()
        else:
            # DescribeTrainingJob succeeded; stop retrying.
            break
    eval_state = sm_eval_job_info.get("TrainingJobStatus", "Pending")
    if eval_state == "Completed":
        eval_score = "n.a."
        if self.local_mode:
            # In local mode, parse the "average loss" metric from the
            # captured VW job log output.
            rgx = re.compile(
                "average loss = ([-+]?[0-9]*\\.?[0-9]+([eE][-+]?[0-9]+)?).*$", re.M
            )
            eval_score_rgx = rgx.findall(self.log_output)
            if len(eval_score_rgx) == 0:
                logger.warning("No eval score available from vw job log.")
            else:
                # findall() returns a list of capture-group tuples, e.g.
                # [('eval_score', '')]; take the first group of the first match.
                eval_score = eval_score_rgx[0][0]
        else:
            # In SageMaker mode, read the "average_loss" metric from
            # TrainingJobAnalytics, retrying a few times in case the
            # metric is not yet published or the call is throttled.
            attempts = 0
            while eval_score == "n.a." and attempts < 4:
                # Count the attempt up front so repeated exceptions cannot
                # loop forever.
                attempts += 1
                try:
                    metric_df = TrainingJobAnalytics(
                        self.model_record._evaluation_job_name, ["average_loss"]
                    ).dataframe()
                    eval_score = str(
                        metric_df[metric_df["metric_name"] == "average_loss"]["value"][0]
                    )
                except Exception:
                    # Back off to avoid throttling before the next attempt.
                    time.sleep(5)

        self.model_record._eval_state = eval_state
        self.model_record.add_model_eval_scores(eval_score)
        self.model_db_client.update_model_eval_job_state(self._jsonify())
    else:
        # update eval state via ddb client
        self.model_record.update_eval_job_state(eval_state)
        self.model_db_client.update_model_eval_job_state(self._jsonify())
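
# A minimal usage sketch, not part of the original file: it shows one way a
# caller might drive this poller. `model_manager`, `poll_evaluation_job`,
# `interval_seconds`, and `timeout_seconds` are hypothetical names; only
# `_update_model_table_evaluation_states` and `eval_in_terminal_state` come
# from the method above, and `time` is assumed to be imported as it is used
# elsewhere in this module.
def poll_evaluation_job(model_manager, interval_seconds=30, timeout_seconds=3600):
    """Refresh evaluation state until the evaluation job reaches a terminal state."""
    deadline = time.time() + timeout_seconds
    while not model_manager.model_record.eval_in_terminal_state():
        if time.time() > deadline:
            raise TimeoutError("Evaluation job did not reach a terminal state in time.")
        model_manager._update_model_table_evaluation_states()
        time.sleep(interval_seconds)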