def _update_experiment_db_training_workflow_metadata()

in common/sagemaker_rl/orchestrator/workflow/manager/experiment_manager.py [0:0]


    def _update_experiment_db_training_workflow_metadata(self, training_workflow_metadata):
        """
        Four things happen here:
        a) Checks if current TrainingWorkflowMetadata needs an update.
        b) Fetches latest TrainingJob state from ModelDb for next_model_to_train
        c) Updates ExperimentDb TrainingWorkflowMetadata with latest information.
        d) Finally, updates the local ExperimentManager context to latest.

        Args:
            training_workflow_metadata (dict): A dictionary containing
                training workflow related metadata
        """
        if training_workflow_metadata is None:
            # A training request hasn't been made yet.
            # Nothing to process. Return.
            return

        next_model_to_train_id = training_workflow_metadata.get("next_model_to_train_id", None)
        training_state = training_workflow_metadata.get("training_state", None)

        if training_state is None:
            # A training request hasn't been made yet.
            # Nothing to process. Return.
            return
        elif not training_state.endswith("ING"):
            # TrainingWorkflowState in ExperimentDb is already in a Final State.
            # The sync thread only updates Pending/Running TrainingWorkflowState values.
            return
        elif training_state.endswith("ING") and next_model_to_train_id is None:
            # A training is in progress, but the training model-id is None!
            logger.warn(
                f"Model Training in {training_state}, while next_model_to_train_id is None. "
                "Training Workflow would be stuck if this continues."
            )
            return
        else:
            # A training is in progress. Fetch the status of that training job from ModelDb.
            training_job_record = self.model_db_client.get_model_record_with_retry(
                self.experiment_id, next_model_to_train_id
            )

            # Determine the updated TrainingWorkflowState and store it in new_training_state.
            if training_job_record is None:
                # Training Job Record not found in ModelDb even with 1 retry, after 5 seconds.
                # Most likely there was a failure creating requested TrainingJob
                # Update the TrainingWorkflowState to Failed.
                logger.warn(
                    f"ModelId {next_model_to_train_id} record not found. Failing the TrainingWorkflow"
                )
                new_training_state = TrainingState.FAILED
            else:
                train_state_from_modeldb = training_job_record.get("train_state")

                if train_state_from_modeldb is not None:
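                    # TRAINING_JOB_STATUS_MAP (defined elsewhere in this codebase) translates the
                    # SageMaker training-job status stored in ModelDb into a TrainingState value.
                    # Illustrative mapping (an assumption, not the exact definition):
                    #   "InProgress" -> an in-progress ("*ING") state, "Completed" -> TRAINED,
                    #   "Failed" -> FAILED, "Stopped" -> STOPPED.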
                    new_training_state = TRAINING_JOB_STATUS_MAP[train_state_from_modeldb]
                else:
                    # Since ModelDb training job state is None,
                    # keep the ExperimentDb TrainingWorkflowState same.
                    logger.warn(
                        f"ModelDb has model-id {next_model_to_train_id} 's state as 'None'. "
                        "Training Worklow would be stuck if this continues."
                    )
                    new_training_state = training_state

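        # Snapshot the model id this decision was based on; the conditional write below
        # only succeeds if ExperimentDb still holds this same next_model_to_train_id.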
        expected_next_model_to_train_id = next_model_to_train_id
        # Generate new TrainingWorkflowState for ExperimentDb based on new_training_state
        if new_training_state == TrainingState.TRAINED:
            training_workflow_metadata["last_trained_model_id"] = next_model_to_train_id
            training_workflow_metadata["next_model_to_train_id"] = None
            training_workflow_metadata["training_state"] = new_training_state

        elif (
            new_training_state == TrainingState.FAILED
            or new_training_state == TrainingState.STOPPED
        ):
            # training_workflow_metadata['last_trained_model_id'] remains the same
            # training_workflow_metadata['next_model_to_train_id'] remains the same, or may be set to None
            # update the ExperimentDb TrainingWorkflowState to Failed/Stopped
            training_workflow_metadata["training_state"] = new_training_state
        else:
            # training_workflow_metadata['last_trained_model_id'] remains the same
            # training_workflow_metadata['next_model_to_train_id'] remains the same
            # update the ExperimentDb TrainingWorkflowState to new_training_state
            training_workflow_metadata["training_state"] = new_training_state

        # Try to save the update in ExperimentDb
        # The update succeeds only if, in the current ExperimentDb record,
        # next_model_to_train_id == expected_next_model_to_train_id
        try:
            self.exp_db_client.update_training_workflow_metadata_with_validation(
                self.experiment_id, training_workflow_metadata, expected_next_model_to_train_id
            )
        except Exception as e:
            if "ConditionalCheckFailedException" in str(e):
                # Most likely Sync Thread went out of sync :(
                # Just return here without updating local ExperimentManager.
                logger.warn(
                    "Sync Thread trying to update ExperimentDb with old state. This should "
                    "get fixed in next run!"
                )
                return
            logger.error("Failed to update ExperimentDb with latest information: " + str(e))
            raise UnhandledWorkflowException(
                "An error occurred while updating the ExperimentDb record's TrainingWorkflowMetadata"
            ) from e

        # Finally, update local ExperimentManager with new states.
        self.experiment_manager.experiment_record._last_trained_model_id = (
            training_workflow_metadata["last_trained_model_id"]
        )
        self.experiment_manager.experiment_record._next_model_to_train_id = (
            training_workflow_metadata["next_model_to_train_id"]
        )
        self.experiment_manager.experiment_record._training_state = training_workflow_metadata[
            "training_state"
        ]
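
The ConditionalCheckFailedException handling above suggests that the validated update is a
conditional write (for example, against a DynamoDB-backed ExperimentDb). A minimal sketch of
that idea, assuming DynamoDB and hypothetical table, key, and attribute names (this is not the
actual ExperimentDbClient implementation):

    import boto3
    from boto3.dynamodb.conditions import Attr

    def update_training_workflow_metadata_with_validation(
        experiment_id, training_workflow_metadata, expected_next_model_to_train_id
    ):
        # Hypothetical table and key names; the real client defines its own schema.
        table = boto3.resource("dynamodb").Table("ExperimentDb")
        # The write succeeds only if the stored next_model_to_train_id still equals the
        # value this sync thread read earlier; otherwise DynamoDB raises
        # ConditionalCheckFailedException, which the caller treats as "out of sync".
        table.update_item(
            Key={"experiment_id": experiment_id},
            UpdateExpression="SET training_workflow_metadata = :twm",
            ConditionExpression=Attr("training_workflow_metadata.next_model_to_train_id").eq(
                expected_next_model_to_train_id
            ),
            ExpressionAttributeValues={":twm": training_workflow_metadata},
        )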