in archived/rl_gamerserver_ray/common/sagemaker_rl/orchestrator/workflow/manager/experiment_manager.py [0:0]
def train_next_model(self, wait=True, input_data_s3_prefix=None, input_model_id=None):
"""
Train a new model given the training data and an optional pre-trained model
Args:
wait (bool): Whether to wait until the training job finishes
input_data_s3_prefix (str or list): S3 prefix (or list of prefixes)
of the data used for the training job
input_model_id (str): Id of the model to use as the pre-trained
model for the training job
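Example (hypothetical S3 prefix; assumes an already initialized
ExperimentManager instance named `experiment_manager`):
experiment_manager.train_next_model(
wait=True,
input_data_s3_prefix="s3://my-bucket/my-experiment/training-data/",
)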
"""
# Sync experiment state if required
self._sync_experiment_state_with_ddb()
# use 'last_trained_model_id' by default as input model for next training
if input_model_id is None and self.experiment_record._last_trained_model_id is not None:
logger.info(
f"Use last trained model {self.experiment_record._last_trained_model_id} "
"as pre-trained model for training"
)
input_model_id = self.experiment_record._last_trained_model_id
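# if a model other than the last trained one was requested, verify it is ready before training on top of it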
if input_model_id != self.experiment_record._last_trained_model_id:
# Do not start a training job if the given pre-trained model is not ready
if not self._check_if_model_ready(input_model_id):
return
# experiment only allows one training job at a time,
# validate no other training request is in progress
if (
self.experiment_record._training_state is not None
and self.experiment_record._training_state.endswith("ING")
):
logger.error(
f"A training request with model id '{self.experiment_record._next_model_to_train_id}' "
f"was in the state of '{self.experiment_record._training_state}'. "
"Please wait until the training job is finished."
)
raise InvalidUsageException(
"Please wait for old Training Job to Complete before requesting a new one!"
)
else:
# update next_model_to_train_id and training state
next_model_to_train_id = ModelManager.name_next_model(experiment_id=self.experiment_id)
logger.info(f"Starting training job for ModelId '{next_model_to_train_id}''")
self.exp_db_client.update_experiment_next_model_to_train_id(
self.experiment_id, next_model_to_train_id
)
self.exp_db_client.update_experiment_training_state(
self.experiment_id, TrainingState.PENDING
)
manifest_file_path = None
if isinstance(input_data_s3_prefix, list):
# generate a manifest file and upload it to S3 when multiple input prefixes are given
manifest_file_path = self._generate_manifest(input_data_s3_prefix)
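# create a ModelManager record for the new model id and use it to launch the training job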
try:
self.next_model_to_train = ModelManager(
model_db_client=self.model_db_client,
experiment_id=self.experiment_id,
model_id=next_model_to_train_id,
image=self.image,
role=self.resource_manager.iam_role_arn,
instance_config=self.resource_manager.training_fleet_config,
boto_session=self.boto_session,
algor_config=self.algor_config,
)
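# start the training job; wait=True blocks until completion and streams the job logs (logs=wait)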
self.next_model_to_train.fit(
wait=wait,
input_model_id=input_model_id,
input_data_s3_prefix=input_data_s3_prefix,
manifest_file_path=manifest_file_path,
logs=wait,
)
except Exception as e:
# log the error; when waiting, the state polling below will surface the failure
logger.error(e)
# wait until the experiment DDB table reflects the finished training job
if self.local_mode or wait:
trained_state = (
self.experiment_record._training_state == TrainingState.TRAINED
and self.experiment_record._last_trained_model_id == next_model_to_train_id
and self.experiment_record._next_model_to_train_id is None
)
num_retries = 0
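# poll with exponential backoff (2s, 4s, 8s, 16s, 32s) until the experiment record shows the job as TRAINED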
while not trained_state:
# Sync experiment state if required
self._sync_experiment_state_with_ddb()
logger.debug("Waiting for experiment table training status to be updated...")
time.sleep(2 * (2 ** num_retries))
trained_state = (
self.experiment_record._training_state == TrainingState.TRAINED
and self.experiment_record._last_trained_model_id == next_model_to_train_id
and self.experiment_record._next_model_to_train_id is None
)
num_retries += 1
if num_retries >= 5:
raise UnhandledWorkflowException(
f"Training job '{self.experiment_record._next_model_to_train_id}' "
f"is still in state '{self.experiment_record._training_state}' after multiple retries. "
"Expected it to reach TRAINED."
)
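# fail fast if the training job ended in a terminal failure state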
if self.experiment_record._training_state in (
TrainingState.FAILED,
TrainingState.STOPPED,
):
raise SageMakerTrainingJobException(
f"Training job '{self.experiment_record._next_model_to_train_id}' "
f"ended in state '{self.experiment_record._training_state}'. Please check the SageMaker logs "
"for more information."
)