in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerEstimator.scala [422:458]
/**
 * Blocks until the named SageMaker training job reaches a terminal status, polling
 * `describeTrainingJob` every `SageMakerEstimator.TrainingJobPollInterval`.
 *
 * Elapsed time is measured with `timeProvider` from the moment this method is entered.
 * Polling stops once `trainingJobTimeout` has elapsed.
 *
 * @param trainingJobName name of the training job to wait on
 * @throws RuntimeException if the job reports `Failed` (message includes the failure reason),
 *                          reports `Stopped` (message includes the stopping condition), or
 *                          does not complete before `trainingJobTimeout` elapses
 * @throws SdkBaseException if `describeTrainingJob` fails with an exception that
 *                          `RetryUtils` does not classify as a retryable service exception
 */
private def awaitTrainingCompletion(trainingJobName : String) : Unit = {
  val startTime = this.timeProvider.currentTimeMillis
  val describeTrainingJobRequest = new DescribeTrainingJobRequest()
    .withTrainingJobName(trainingJobName)
  InternalUtils.applyUserAgent(describeTrainingJobRequest)
  log.info(s"Begin waiting for training job $trainingJobName")
  while (this.timeProvider.currentTimeMillis - startTime < trainingJobTimeout.toMillis) {
    try {
      val response = sagemakerClient.describeTrainingJob(describeTrainingJobRequest)
      // NOTE(review): fromValue presumably throws IllegalArgumentException for a status string
      // this SDK version does not know; that exception is not retried and aborts the wait.
      val currentStatus = TrainingJobStatus.fromValue(response.getTrainingJobStatus)
      log.info(s"Training job status: $currentStatus")
      currentStatus match {
        // Plain local return out of the polling loop (not inside a closure).
        case TrainingJobStatus.Completed => return
        case TrainingJobStatus.Failed =>
          val message = s"Training job '$trainingJobName' failed for reason:" +
            s" '${response.getFailureReason}'"
          throw new RuntimeException(message)
        case TrainingJobStatus.Stopped =>
          val message = s"Training job '$trainingJobName' stopped. Stopping condition:" +
            s" '${response.getStoppingCondition}'"
          throw new RuntimeException(message)
        case _ => // any non-terminal status: keep polling
      }
    } catch {
      // Only transient SDK/service errors are swallowed and retried on the next poll.
      // Every other exception, including non-retryable SDK errors and the RuntimeExceptions
      // thrown above, propagates to the caller unchanged.
      case e : SdkBaseException if RetryUtils.isRetryableServiceException(e) =>
        log.warn(s"Retryable exception: ${e.getMessage}", e)
    }
    timeProvider.sleep(SageMakerEstimator.TrainingJobPollInterval.toMillis)
  }
  throw new RuntimeException(s"Timed out after ${trainingJobTimeout.toString} while waiting for" +
    s" Training Job '$trainingJobName' to complete.")
}