override def fit()

in sagemaker-spark-sdk/src/main/scala/com/amazonaws/services/sagemaker/sparksdk/SageMakerEstimator.scala [289:347]


  /**
   * Trains a model on `dataSet` by staging it in S3 and running a SageMaker training job.
   *
   * Steps, in order:
   *  1. Validates the dataset schema via `transformSchema` (with logging enabled).
   *  2. Creates a fresh name policy and takes the training job name from it.
   *  3. Uploads the dataset to the S3 path resolved from `trainingInputS3DataPath`. When
   *     `trainingProjectedColumns` is a non-empty sequence, only those columns are uploaded.
   *     The upload time is logged.
   *  4. Records the job name in `latestTrainingJob`, builds the CreateTrainingJobRequest and
   *     passes it to `runTrainingJob`.
   *  5. If `deleteStagingDataAfterTraining` is set, deletes the staged input data, whether or not
   *     step 4 succeeded.
   *  6. Describes the training job to obtain the S3 URI of its model artifacts.
   *
   * Error handling: if step 4 throws and the staging-data deletion then fails with a non-fatal
   * error, the step-4 error is rethrown with the deletion error attached as a suppressed
   * exception, so the root cause is not masked. If step 4 succeeded, a deletion failure
   * propagates to the caller.
   *
   * @param dataSet the training data; its schema is checked by `transformSchema`
   * @return a [[SageMakerModel]] whose model path is the trained artifacts' S3 URI. It is
   *         configured with this estimator's endpoint instance type/count, row (de)serializers,
   *         model image, environment variables, resolved role ARN, endpoint creation policy,
   *         SageMaker client, prepend flag, name policy and uid.
   */
  override def fit(dataSet: Dataset[_]): SageMakerModel = {
    import scala.util.control.NonFatal

    transformSchema(dataSet.schema, logging = true)
    val namePolicy = namePolicyFactory.createNamePolicy
    val trainingJobName = namePolicy.trainingJobName

    val conf = dataSet.sparkSession.sparkContext.getConf
    val inputPath = resolveS3Path(trainingInputS3DataPath, trainingJobName, conf)

    val startingS3UploadTime = this.timeProvider.currentTimeMillis

    // Upload only the projected columns when a non-empty projection is configured.
    val dataUploadResults = trainingProjectedColumns match {
      case Some(columns) if columns.nonEmpty =>
        dataUploader.uploadData(inputPath, dataSet.select(columns.head, columns.tail: _*))
      case _ => dataUploader.uploadData(inputPath, dataSet)
    }

    val s3UploadTime = this.timeProvider.getElapsedTimeInSeconds(startingS3UploadTime)
    log.info(s"S3 Upload Time: $s3UploadTime s")

    // Removes the staged input data when configured to do so; otherwise a no-op.
    def deleteStagingData(): Unit =
      if (deleteStagingDataAfterTraining) {
        log.info(s"Deleting training data ${inputPath.toS3UriString} of job with" +
          s" name $trainingJobName")
        deleteTrainingData(inputPath)
      }

    try {
      log.info(s"Creating training job with name $trainingJobName")
      latestTrainingJob = Some(trainingJobName)
      val createTrainingJobRequest = buildCreateTrainingJobRequest(trainingJobName,
        dataUploadResults, conf)

      log.info(s"CreateTrainingJobRequest: $createTrainingJobRequest")
      runTrainingJob(createTrainingJobRequest, trainingJobName)
    } catch {
      // Caught only to clean up and then rethrow the same error. This mirrors a `finally` on the
      // failure path, but a cleanup failure is attached as suppressed instead of replacing
      // (and hiding) the training error.
      case trainingError: Throwable =>
        try {
          deleteStagingData()
        } catch {
          case NonFatal(cleanupError) => trainingError.addSuppressed(cleanupError)
        }
        throw trainingError
    }
    // Success path: clean up before describing the job. A cleanup failure propagates here.
    deleteStagingData()

    // NOTE(review): assumes runTrainingJob returns only once the job has finished, so that
    // ModelArtifacts is populated on the describe response. Confirm against runTrainingJob.
    val describeTrainingJobRequest = new DescribeTrainingJobRequest()
      .withTrainingJobName(trainingJobName)
    InternalUtils.applyUserAgent(describeTrainingJobRequest)
    val modelS3URI = sagemakerClient.describeTrainingJob(describeTrainingJobRequest)
      .getModelArtifacts
      .getS3ModelArtifacts

    log.info(s"Model S3 URI: $modelS3URI")
    new SageMakerModel(
      Some(endpointInstanceType),
      Some(endpointInitialInstanceCount),
      requestRowSerializer,
      responseRowDeserializer,
      None,
      Some(modelImage),
      Some(S3DataPath.fromS3URI(modelS3URI)),
      modelEnvironmentVariables,
      Some(resolveRoleARN(sagemakerRole, conf).role),
      endpointCreationPolicy,
      sagemakerClient,
      modelPrependInputRowsToTransformationRows,
      namePolicy,
      uid)
  }