def _update_model_table_evaluation_states()

in common/sagemaker_rl/orchestrator/workflow/manager/model_manager.py [0:0]


    def _update_model_table_evaluation_states(self):
        """Poll the SageMaker evaluation job and sync its state into the model table.

        The evaluation job named by ``self.model_record._evaluation_job_name``
        is described via ``self.sagemaker_client.describe_training_job`` and
        the following evaluation metadata of the model record is updated:
            eval_state,
            eval_scores

        When the job is ``Completed``, the eval score is parsed from
        ``self.log_output`` in local mode, or read from the ``average_loss``
        metric via ``TrainingJobAnalytics`` otherwise. If no score can be
        obtained it is recorded as ``"n.a."``.

        Returns:
            dict or None: ``self._jsonify()`` when the evaluation is already in
            a terminal state, or when ``DescribeTrainingJob`` fails with a
            non-validation error (the poll is skipped and can be retried on a
            later call). ``None`` after the evaluation job has been marked as
            failed, or after the eval state has been written to the model table.
        """

        if self.model_record.eval_in_terminal_state():
            self.model_db_client.update_model_record(self._jsonify())
            return self._jsonify()

        # Try and fetch updated SageMaker Training Job Status. A ValidationException
        # is retried a few times; once the attempts are exhausted the evaluation
        # job is marked as failed.
        max_describe_attempts = 3
        sm_eval_job_info = {}
        for attempt in range(max_describe_attempts):
            try:
                sm_eval_job_info = self.sagemaker_client.describe_training_job(
                    TrainingJobName=self.model_record._evaluation_job_name
                )
                # Stop on the first successful describe; without this the job
                # would be described on every remaining iteration.
                break
            except Exception as e:
                if "ValidationException" not in str(e):
                    # Do not raise exception, most probably throttling.
                    logger.warning(
                        "Failed to check SageMaker Training Job state for EvaluationJob: "
                        f" {self.model_record._evaluation_job_name}. This exception will be ignored,"
                        " and retried."
                    )
                    time.sleep(2)
                    return self._jsonify()
                if attempt >= max_describe_attempts - 1:
                    # Last attempt for DescribeTrainingJob with validation failure
                    logger.warning(
                        "Looks like SageMaker Job was not submitted successfully."
                        f" Failing EvaluationJob {self.model_record._evaluation_job_name}"
                    )
                    self.model_record.update_eval_job_as_failed()
                    self.model_db_client.update_model_eval_as_failed(self._jsonify())
                    return
                time.sleep(5)

        eval_state = sm_eval_job_info.get("TrainingJobStatus", "Pending")
        if eval_state == "Completed":
            eval_score = "n.a."

            if self.local_mode:
                # In local mode the score is scraped from the captured job log.
                rgx = re.compile(
                    r"average loss = ([-+]?[0-9]*\.?[0-9]+([eE][-+]?[0-9]+)?).*$", re.M
                )
                eval_score_rgx = rgx.findall(self.log_output)

                if not eval_score_rgx:
                    logger.warning("No eval score available from vw job log.")
                else:
                    # findall yields (number, exponent) tuples; keep the full number.
                    eval_score = eval_score_rgx[0][0]
            else:
                # Bounded retries: every attempt (successful or not) counts, so a
                # persistent failure can no longer spin forever.
                max_metric_attempts = 4
                for attempt in range(max_metric_attempts):
                    try:
                        metric_df = TrainingJobAnalytics(
                            self.model_record._evaluation_job_name, ["average_loss"]
                        ).dataframe()
                        loss_rows = metric_df[metric_df["metric_name"] == "average_loss"]
                        # Positional access: the first average_loss row regardless
                        # of the dataframe's index labels.
                        eval_score = str(loss_rows["value"].iloc[0])
                        break
                    except Exception:
                        if attempt < max_metric_attempts - 1:
                            # to avoid throttling
                            time.sleep(5)
                if eval_score == "n.a.":
                    logger.warning(
                        "No average_loss metric available for EvaluationJob "
                        f"{self.model_record._evaluation_job_name}."
                    )
            self.model_record._eval_state = eval_state
            self.model_record.add_model_eval_scores(eval_score)
            self.model_db_client.update_model_eval_job_state(self._jsonify())
        else:
            # update eval state via ddb client
            self.model_record.update_eval_job_state(eval_state)
            self.model_db_client.update_model_eval_job_state(self._jsonify())