in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerEstimator.scala [289:347]
/**
 * Fits a [[SageMakerModel]] to the given dataset.
 *
 * Steps, in order:
 *  1. Validates the dataset schema through `transformSchema`.
 *  2. Uploads the dataset to the S3 location resolved from `trainingInputS3DataPath`; when
 *     `trainingProjectedColumns` holds a non-empty column list, only those columns are uploaded.
 *  3. Builds a CreateTrainingJobRequest and hands it to `runTrainingJob`.
 *  4. If `deleteStagingDataAfterTraining` is set, deletes the uploaded data, whether or not
 *     step 3 threw.
 *  5. Describes the training job to read the S3 URI of its model artifacts and wraps that in
 *     a new [[SageMakerModel]].
 *
 * The name of the job is recorded in `latestTrainingJob` before the request is built.
 *
 * @param dataSet the data to train on
 * @return a model backed by the artifacts the training job produced
 */
override def fit(dataSet: Dataset[_]): SageMakerModel = {
  transformSchema(dataSet.schema, logging = true)

  // One name policy is used for the training job and is later handed to the model.
  val jobNamePolicy = namePolicyFactory.createNamePolicy
  val trainingJobName = jobNamePolicy.trainingJobName
  val sparkConf = dataSet.sparkSession.sparkContext.getConf
  val inputPath = resolveS3Path(trainingInputS3DataPath, trainingJobName, sparkConf)

  // Stage the training data in S3 and report how long that took.
  val uploadStartMillis = this.timeProvider.currentTimeMillis
  val uploadResults = trainingProjectedColumns match {
    case Some(projection) if projection.nonEmpty =>
      dataUploader.uploadData(inputPath, dataSet.select(projection.head, projection.tail: _*))
    case _ =>
      dataUploader.uploadData(inputPath, dataSet)
  }
  val s3UploadTime = this.timeProvider.getElapsedTimeInSeconds(uploadStartMillis)
  log.info(s"S3 Upload Time: $s3UploadTime s")

  try {
    log.info(s"Creating training job with name $trainingJobName")
    latestTrainingJob = Some(trainingJobName)
    val createTrainingJobRequest =
      buildCreateTrainingJobRequest(trainingJobName, uploadResults, sparkConf)
    log.info(s"CreateTrainingJobRequest: ${createTrainingJobRequest.toString}")
    runTrainingJob(createTrainingJobRequest, trainingJobName)
  } finally {
    // Cleanup sits in `finally` so it also runs when the training step throws.
    if (deleteStagingDataAfterTraining) {
      log.info(s"Deleting training data ${inputPath.toS3UriString} of job with" +
        s" name $trainingJobName")
      deleteTrainingData(inputPath)
    }
  }

  // Look the job up again to find where its model artifacts were written.
  val describeRequest = new DescribeTrainingJobRequest().withTrainingJobName(trainingJobName)
  InternalUtils.applyUserAgent(describeRequest)
  val describeResult = sagemakerClient.describeTrainingJob(describeRequest)
  val modelS3URI = describeResult.getModelArtifacts.getS3ModelArtifacts
  log.info(s"Model S3 URI: $modelS3URI")

  val modelArtifactPath = S3DataPath.fromS3URI(modelS3URI)
  val roleArn = resolveRoleARN(sagemakerRole, sparkConf).role

  new SageMakerModel(
    Some(endpointInstanceType),
    Some(endpointInitialInstanceCount),
    requestRowSerializer,
    responseRowDeserializer,
    Option.empty,
    Some(modelImage),
    Some(modelArtifactPath),
    modelEnvironmentVariables,
    Some(roleArn),
    endpointCreationPolicy,
    sagemakerClient,
    modelPrependInputRowsToTransformationRows,
    jobNamePolicy,
    uid)
}