in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerModel.scala [79:122]
def fromTrainingJob(trainingJobName: String,
                    modelImage: String,
                    modelExecutionRoleARN: String,
                    endpointInstanceType: String,
                    endpointInitialInstanceCount : Int,
                    requestRowSerializer: RequestRowSerializer,
                    responseRowDeserializer: ResponseRowDeserializer,
                    modelEnvironmentVariables: Map[String, String] = Map[String, String](),
                    endpointCreationPolicy: EndpointCreationPolicy =
                      EndpointCreationPolicy.CREATE_ON_CONSTRUCT,
                    sagemakerClient : AmazonSageMaker
                      = AmazonSageMakerClientBuilder.defaultClient,
                    prependResultRows : Boolean = true,
                    namePolicy : NamePolicy = new RandomNamePolicy(),
                    uid: String = Identifiable.randomUID("sagemaker")) : SageMakerModel = {
  // Builds a SageMakerModel backed by the S3 model artifact of an existing training job.
  // The model must be hosted on an endpoint, so DO_NOT_CREATE is rejected before any
  // request is sent to SageMaker.
  require(endpointCreationPolicy != EndpointCreationPolicy.DO_NOT_CREATE,
    "Endpoint creation policy must not be DO_NOT_CREATE to" +
      " create an endpoint from a training job name.")

  // Look up the training job (tagging the request with the SDK user agent).
  val jobLookup = new DescribeTrainingJobRequest().withTrainingJobName(trainingJobName)
  InternalUtils.applyUserAgent(jobLookup)
  val jobDescription = sagemakerClient.describeTrainingJob(jobLookup)

  // Only jobs that have finished (normally or by being stopped) are accepted.
  val jobStatus = TrainingJobStatus.fromValue(jobDescription.getTrainingJobStatus)
  val acceptedStatuses = Set(TrainingJobStatus.Completed, TrainingJobStatus.Stopped)
  require(acceptedStatuses.contains(jobStatus),
    s"Can only create a SageMakerModel from a training job with status Completed or Stopped, not status $jobStatus")

  // The job's S3 artifact location becomes the hosted model's data path.
  val artifactLocation = S3DataPath.fromS3URI(jobDescription.getModelArtifacts.getS3ModelArtifacts)

  new SageMakerModel(
    modelImage = Some(modelImage),
    modelPath = Some(artifactLocation),
    requestRowSerializer = requestRowSerializer,
    responseRowDeserializer = responseRowDeserializer,
    modelEnvironmentVariables = modelEnvironmentVariables,
    modelExecutionRoleARN = Some(modelExecutionRoleARN),
    endpointCreationPolicy = endpointCreationPolicy,
    endpointInstanceType = Some(endpointInstanceType),
    endpointInitialInstanceCount = Some(endpointInitialInstanceCount),
    sagemakerClient = sagemakerClient,
    prependResultRows = prependResultRows,
    namePolicy = namePolicy,
    uid = uid)
}