in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerEstimator.scala [349:405]
/**
 * Builds the request used to start a SageMaker training job for this estimator.
 *
 * Apart from the values derived from the arguments below, every field of the request is
 * taken from this estimator's own training settings.
 *
 * @param trainingJobName   name given to the training job; also passed to `resolveS3Path`
 *                          when computing the output location
 * @param dataUploadResults result of uploading the training data; supplies the input S3 URI
 *                          and decides whether the input is declared as an S3 prefix or as a
 *                          manifest file
 * @param conf              Spark configuration, forwarded to `resolveS3Path` and
 *                          `resolveRoleARN`
 * @return the populated request, after it has been passed through
 *         `InternalUtils.applyUserAgent`
 */
private[sparksdk] def buildCreateTrainingJobRequest(trainingJobName: String,
    dataUploadResults: DataUploadResult,
    conf: SparkConf): CreateTrainingJobRequest = {
  val request = new CreateTrainingJobRequest()
  InternalUtils.applyUserAgent(request)

  // An empty hyperparameter map is not sent as-is: the request receives null instead.
  val userHyperParameters = makeHyperParameters()
  val hyperParametersOrNull =
    if (userHyperParameters.isEmpty) null else userHyperParameters

  // The declared S3 data type follows the kind of upload that produced the training data.
  val inputUri = dataUploadResults.s3DataPath.toS3UriString
  val s3DataType = dataUploadResults match {
    case ObjectPrefixUploadResult(_) => S3DataType.S3Prefix.toString
    case ManifestDataUploadResult(_) => S3DataType.ManifestFile.toString
  }

  // Single input channel. Optional codec and content type become null when unset.
  val inputChannel = new Channel()
    .withChannelName(trainingChannelName)
    .withCompressionType(trainingCompressionCodec.orNull)
    .withContentType(trainingContentType.orNull)
    .withDataSource(new DataSource().withS3DataSource(new S3DataSource()
      .withS3Uri(inputUri)
      .withS3DataType(s3DataType)
      .withS3DataDistributionType(trainingS3DataDistribution)))

  val outputPath = resolveS3Path(trainingOutputS3DataPath, trainingJobName, conf).toS3UriString
  val roleArn = resolveRoleARN(sagemakerRole, conf).role

  request
    .withTrainingJobName(trainingJobName)
    .withAlgorithmSpecification(new AlgorithmSpecification()
      .withTrainingImage(trainingImage)
      .withTrainingInputMode(trainingInputMode))
    .withHyperParameters(hyperParametersOrNull)
    .withInputDataConfig(inputChannel)
    .withOutputDataConfig(new OutputDataConfig()
      .withS3OutputPath(outputPath)
      .withKmsKeyId(trainingKmsKeyId.orNull))
    .withResourceConfig(new ResourceConfig()
      .withInstanceCount(trainingInstanceCount)
      .withInstanceType(trainingInstanceType)
      .withVolumeSizeInGB(trainingInstanceVolumeSizeInGB))
    .withRoleArn(roleArn)
    .withStoppingCondition(new StoppingCondition()
      .withMaxRuntimeInSeconds(trainingMaxRuntimeInSeconds))

  request
}