def fromTrainingJob()

in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerModel.scala [79:122]


  /**
   * Creates a [[SageMakerModel]] from the model artifacts of an existing SageMaker training job.
   *
   * The training job is looked up with a `DescribeTrainingJob` call made through
   * `sagemakerClient`. Its S3 model artifacts location becomes the model path of the returned
   * model. All other arguments are forwarded unchanged to the [[SageMakerModel]] constructor.
   *
   * @param trainingJobName name of the training job whose model artifacts are used
   * @param modelImage model image, forwarded to the model as `Some(modelImage)`
   * @param modelExecutionRoleARN role ARN, forwarded to the model as `Some(modelExecutionRoleARN)`
   * @param endpointInstanceType forwarded to the model as `Some(endpointInstanceType)`
   * @param endpointInitialInstanceCount forwarded to the model as
   *        `Some(endpointInitialInstanceCount)`
   * @param requestRowSerializer forwarded to the model
   * @param responseRowDeserializer forwarded to the model
   * @param modelEnvironmentVariables forwarded to the model; defaults to an empty map
   * @param endpointCreationPolicy forwarded to the model; must not be `DO_NOT_CREATE`.
   *        Defaults to `CREATE_ON_CONSTRUCT`
   * @param sagemakerClient client used to describe the training job; also forwarded to the
   *        model. Defaults to `AmazonSageMakerClientBuilder.defaultClient`
   * @param prependResultRows forwarded to the model; defaults to `true`
   * @param namePolicy forwarded to the model; defaults to a new `RandomNamePolicy`
   * @param uid forwarded to the model; defaults to `Identifiable.randomUID("sagemaker")`
   * @return a [[SageMakerModel]] whose model path is the training job's S3 model artifacts
   *         location
   * @throws IllegalArgumentException if `endpointCreationPolicy` is `DO_NOT_CREATE`, if the
   *         training job's status is neither `Completed` nor `Stopped`, or if the describe
   *         response carries no (or an empty) S3 model artifacts location
   */
  def fromTrainingJob(trainingJobName: String,
                      modelImage: String,
                      modelExecutionRoleARN: String,
                      endpointInstanceType: String,
                      endpointInitialInstanceCount : Int,
                      requestRowSerializer: RequestRowSerializer,
                      responseRowDeserializer: ResponseRowDeserializer,
                      modelEnvironmentVariables: Map[String, String] = Map[String, String](),
                      endpointCreationPolicy: EndpointCreationPolicy =
                        EndpointCreationPolicy.CREATE_ON_CONSTRUCT,
                      sagemakerClient : AmazonSageMaker
                        = AmazonSageMakerClientBuilder.defaultClient,
                      prependResultRows : Boolean = true,
                      namePolicy : NamePolicy = new RandomNamePolicy(),
                      uid: String = Identifiable.randomUID("sagemaker")) : SageMakerModel = {
    // Validate the policy before any service call so a bad argument costs no API request.
    require(endpointCreationPolicy != EndpointCreationPolicy.DO_NOT_CREATE,
      "Endpoint creation policy must not be DO_NOT_CREATE to" +
        " create an endpoint from a training job name.")

    val describeTrainingJobRequest = new DescribeTrainingJobRequest()
      .withTrainingJobName(trainingJobName)
    InternalUtils.applyUserAgent(describeTrainingJobRequest)
    val response = sagemakerClient.describeTrainingJob(describeTrainingJobRequest)

    val status = TrainingJobStatus.fromValue(response.getTrainingJobStatus)
    require(status == TrainingJobStatus.Completed || status == TrainingJobStatus.Stopped,
      "Can only create a SageMakerModel from a training job with status" +
        " Completed or Stopped, not status " + status.toString)

    // AWS Java SDK getters return null for absent response fields. Convert a missing artifacts
    // object or URI into a descriptive failure here, rather than an NPE on the getter chain or
    // a null/empty URI reaching S3DataPath.fromS3URI.
    val modelArtifactsUri = Option(response.getModelArtifacts)
      .flatMap(artifacts => Option(artifacts.getS3ModelArtifacts))
      .filter(_.nonEmpty)
      .getOrElse(throw new IllegalArgumentException(
        s"Training job $trainingJobName (status $status) did not report an S3 model" +
          " artifacts location; cannot create a SageMakerModel from it."))

    new SageMakerModel(modelImage = Some(modelImage),
      modelPath = Some(S3DataPath.fromS3URI(modelArtifactsUri)),
      requestRowSerializer = requestRowSerializer,
      responseRowDeserializer = responseRowDeserializer,
      modelEnvironmentVariables = modelEnvironmentVariables,
      modelExecutionRoleARN = Some(modelExecutionRoleARN),
      endpointCreationPolicy = endpointCreationPolicy,
      endpointInstanceType = Some(endpointInstanceType),
      endpointInitialInstanceCount = Some(endpointInitialInstanceCount),
      sagemakerClient = sagemakerClient,
      prependResultRows = prependResultRows,
      namePolicy = namePolicy,
      uid = uid)
  }