in common/sagemaker_rl/orchestrator/workflow/manager/experiment_manager.py [0:0]
def _update_experiment_db_training_workflow_metadata(self, training_workflow_metadata):
    """
    Four things happen here:
    a) Checks if the current TrainingWorkflowMetadata needs an update.
    b) Fetches the latest TrainingJob state from ModelDb for next_model_to_train.
    c) Updates the ExperimentDb TrainingWorkflowMetadata with the latest information.
    d) Finally, updates the local ExperimentManager context to the latest state.

    Args:
        training_workflow_metadata (dict): A dictionary containing
            training workflow related metadata
    """
    if training_workflow_metadata is None:
        # A training request hasn't been made yet.
        # Nothing to process. Return.
        return
    next_model_to_train_id = training_workflow_metadata.get("next_model_to_train_id", None)
    training_state = training_workflow_metadata.get("training_state", None)
    if training_state is None:
        # A training request hasn't been made yet.
        # Nothing to process. Return.
        return
    elif not training_state.endswith("ING"):
        # TrainingWorkflowState in ExperimentDb is already in a final state.
        # The sync thread only updates on a Pending/Running TrainingWorkflowState.
        return
    elif training_state.endswith("ING") and next_model_to_train_id is None:
        # A training is in progress, but the training model-id is None!
        logger.warn(
            f"Model training is in state {training_state}, but next_model_to_train_id is None. "
            "The training workflow would be stuck if this continues."
        )
        return
    else:
        # A training is in progress. Fetch the status of that training job from ModelDb.
        training_job_record = self.model_db_client.get_model_record_with_retry(
            self.experiment_id, next_model_to_train_id
        )
        # Resolve the updated TrainingWorkflowState into new_training_state.
        if training_job_record is None:
            # Training job record not found in ModelDb even with 1 retry, after 5 seconds.
            # Most likely there was a failure creating the requested TrainingJob.
            # Update the TrainingWorkflowState to Failed.
            logger.warn(
                f"ModelId {next_model_to_train_id} record not found. Failing the TrainingWorkflow"
            )
            new_training_state = TrainingState.FAILED
        else:
            train_state_from_modeldb = training_job_record.get("train_state")
            if train_state_from_modeldb is not None:
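                # TRAINING_JOB_STATUS_MAP (defined elsewhere in this module) is assumed to
                # translate the ModelDb train_state value (e.g. SageMaker-style statuses such
                # as "Pending", "InProgress", "Completed", "Failed", "Stopped") into the
                # workflow-level TrainingState constants used by this method.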
                new_training_state = TRAINING_JOB_STATUS_MAP[train_state_from_modeldb]
            else:
                # Since the ModelDb training job state is None,
                # keep the ExperimentDb TrainingWorkflowState the same.
                logger.warn(
                    f"ModelDb has model-id {next_model_to_train_id}'s state as 'None'. "
                    "The training workflow would be stuck if this continues."
                )
                new_training_state = training_state
    expected_next_model_to_train_id = next_model_to_train_id

    # Generate the new TrainingWorkflowState for ExperimentDb based on new_training_state.
    if new_training_state == TrainingState.TRAINED:
        training_workflow_metadata["last_trained_model_id"] = next_model_to_train_id
        training_workflow_metadata["next_model_to_train_id"] = None
        training_workflow_metadata["training_state"] = new_training_state
    elif (
        new_training_state == TrainingState.FAILED
        or new_training_state == TrainingState.STOPPED
    ):
        # training_workflow_metadata['last_trained_model_id'] remains the same
        # training_workflow_metadata['next_model_to_train_id'] remains the same or changes to None
        # Update the ExperimentDb TrainingWorkflowState to Failed/Stopped.
        training_workflow_metadata["training_state"] = new_training_state
    else:
        # training_workflow_metadata['last_trained_model_id'] remains the same
        # training_workflow_metadata['next_model_to_train_id'] remains the same
        # Update the ExperimentDb TrainingWorkflowState to new_training_state.
        training_workflow_metadata["training_state"] = new_training_state
    # Try to save the update in ExperimentDb.
    # This can update the status only if, in the current record,
    # next_model_to_train_id == expected_next_model_to_train_id.
    try:
        self.exp_db_client.update_training_workflow_metadata_with_validation(
            self.experiment_id, training_workflow_metadata, expected_next_model_to_train_id
        )
    except Exception as e:
        if "ConditionalCheckFailedException" in str(e):
            # Most likely the Sync Thread went out of sync :(
            # Just return here without updating the local ExperimentManager.
            logger.warn(
                "Sync Thread trying to update ExperimentDb with old state. This should "
                "get fixed in the next run!"
            )
            return
        logger.error("Failed to update ExperimentDb with latest information: " + str(e))
        raise UnhandledWorkflowException(
            "Some error occurred while updating the ExperimentDb record's TrainingWorkflowMetadata"
        ) from e
    # Finally, update the local ExperimentManager with the new states.
    self.experiment_manager.experiment_record._last_trained_model_id = (
        training_workflow_metadata["last_trained_model_id"]
    )
    self.experiment_manager.experiment_record._next_model_to_train_id = (
        training_workflow_metadata["next_model_to_train_id"]
    )
    self.experiment_manager.experiment_record._training_state = training_workflow_metadata[
        "training_state"
    ]
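
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): the conditional update
# performed by update_training_workflow_metadata_with_validation above is
# assumed to rely on an optimistic-concurrency write, e.g. a DynamoDB
# ConditionExpression that only succeeds while the stored next_model_to_train_id
# still matches what this sync pass read. Table name, key schema, and the helper
# name below are hypothetical.
# ---------------------------------------------------------------------------
import boto3
from boto3.dynamodb.conditions import Attr


def _conditional_update_sketch(table_name, experiment_id, metadata, expected_next_model_id):
    """Write TrainingWorkflowMetadata only if next_model_to_train_id is unchanged."""
    table = boto3.resource("dynamodb").Table(table_name)
    # Raises ClientError with code "ConditionalCheckFailedException" if another
    # writer has already moved next_model_to_train_id; the caller above treats
    # that as "sync thread is behind, retry on the next run".
    table.update_item(
        Key={"experiment_id": experiment_id},
        UpdateExpression="SET training_workflow_metadata = :m",
        ExpressionAttributeValues={":m": metadata},
        ConditionExpression=Attr(
            "training_workflow_metadata.next_model_to_train_id"
        ).eq(expected_next_model_id),
    )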