def train_next_model()

in common/sagemaker_rl/orchestrator/workflow/manager/experiment_manager.py [0:0]


    def train_next_model(self, wait=True, input_data_s3_prefix=None, input_model_id=None):
        """
        Train a new model given the training data and a pretrained model

        The experiment allows only one training job at a time. A new model id
        is generated and written to the experiment table together with a
        PENDING training state before the training job is launched.

        Args:
            wait (bool): Whether to wait until the training finish. When True
                (or when running in local mode) the experiment table is polled
                afterwards until it reports the new model as TRAINED.
            input_data_s3_prefix (str or list): S3 data path containing data
                used for the training job. When a list of paths is given, a
                manifest file covering all of them is generated and passed to
                the training job.
            input_model_id (str): A model id to specify which model to
                use as a pre-trained model for the training job. Defaults to
                the experiment's last trained model, if there is one.

        Returns:
            None. Returns early, without starting a job, when an input model
            other than the last trained one is reported as not ready by
            ``_check_if_model_ready``.

        Raises:
            InvalidUsageException: If another training request is still in
                progress (its training state ends with "ING").
            SageMakerTrainingJobException: If, while waiting, the training
                state becomes FAILED or STOPPED.
            UnhandledWorkflowException: If, while waiting, the experiment
                table still does not report the new model as TRAINED at the
                last poll.
        """
        # Sync experiment state if required
        self._sync_experiment_state_with_ddb()

        last_trained_model_id = self.experiment_record._last_trained_model_id

        # use 'last_trained_model_id' by default as input model for next training
        if input_model_id is None and last_trained_model_id is not None:
            logger.info(f"Use last trained model {last_trained_model_id} "
                        "as pre-trained model for training")
            input_model_id = last_trained_model_id

        # No training if an explicitly requested input model is not ready
        if input_model_id != last_trained_model_id and not self._check_if_model_ready(input_model_id):
            return

        # experiment only allows one training job at a time,
        # validate no other training request is in progress
        training_state = self.experiment_record._training_state
        if training_state is not None and training_state.endswith("ING"):
            logger.error(f"A training request with model id '{self.experiment_record._next_model_to_train_id}' "
                         f"was in the state of '{training_state}'. "
                         "Please wait until the training job is finished.")
            raise InvalidUsageException("Please wait for old Training Job to Complete before requesting a new one!")

        # update next_model_to_train_id and training state
        next_model_to_train_id = ModelManager.name_next_model(experiment_id=self.experiment_id)

        logger.info(f"Starting training job for ModelId '{next_model_to_train_id}'")

        self.exp_db_client.update_experiment_next_model_to_train_id(
            self.experiment_id,
            next_model_to_train_id)
        self.exp_db_client.update_experiment_training_state(
            self.experiment_id,
            TrainingState.PENDING)

        manifest_file_path = None
        if isinstance(input_data_s3_prefix, list):
            # generate manifest file and upload to s3 when having multiple inputs
            manifest_file_path = self._generate_manifest(input_data_s3_prefix)

        try:
            self.next_model_to_train = ModelManager(
                model_db_client=self.model_db_client,
                experiment_id=self.experiment_id,
                model_id=next_model_to_train_id,
                image=self.image,
                role=self.resource_manager.iam_role_arn,
                instance_config=self.resource_manager.training_fleet_config,
                boto_session=self.boto_session,
                algor_config=self.algor_config
            )
            self.next_model_to_train.fit(wait=wait,
                                         input_model_id=input_model_id,
                                         input_data_s3_prefix=input_data_s3_prefix,
                                         manifest_file_path=manifest_file_path,
                                         logs=wait)
        except Exception as e:
            # NOTE(review): training errors are logged, not re-raised. When
            # waiting, the poll below surfaces a FAILED/STOPPED state (or gives
            # up); when not waiting, this log is the only signal to the caller.
            # Confirm this best-effort behavior is intended.
            logger.exception(e)

        # wait until exp ddb table updated
        if self.local_mode or wait:
            max_polls = 5
            for attempt in range(max_polls):
                # Sync experiment state if required; re-read the record afterwards
                # so every check below sees freshly synced state.
                self._sync_experiment_state_with_ddb()
                record = self.experiment_record

                if record._training_state == TrainingState.TRAINED \
                        and record._last_trained_model_id == next_model_to_train_id \
                        and record._next_model_to_train_id is None:
                    break

                # Terminal failure states are checked before the poll limit so the
                # caller gets the more specific exception.
                if record._training_state == TrainingState.FAILED \
                        or record._training_state == TrainingState.STOPPED:
                    raise SageMakerTrainingJobException(
                        f"Training job '{record._next_model_to_train_id}' "
                        f"ended in state of '{record._training_state}'. Please check Sagemaker logs for "
                        "more information.")

                if attempt == max_polls - 1:
                    raise UnhandledWorkflowException(
                        f"Training job '{record._next_model_to_train_id}' "
                        f"was in state of '{record._training_state}'. Expected it to be TRAINED.")

                logger.debug("Waiting for experiment table training status to be updated...")
                # exponential backoff between polls: 2, 4, 8, 16 seconds
                time.sleep(2 * (2 ** attempt))