override def fit(dataset: Dataset[_]): TuneHyperparametersModel

in core/src/main/scala/com/microsoft/azure/synapse/ml/automl/TuneHyperparameters.scala [144:214]


  override def fit(dataset: Dataset[_]): TuneHyperparametersModel = {
    logFit({
      val sparkSession = dataset.sparkSession
      val splits = MLUtils.kFold(dataset.toDF.rdd, getNumFolds, getSeed)
      val hyperParams = getParamSpace.paramMaps
      val schema = dataset.schema
      val executionContext = getExecutionContext
      val (evaluationMetricColumnName, operator): (String, Ordering[Double]) =
        EvaluationUtils.getMetricWithOperator(getModels.head, getEvaluationMetric)
      val paramsPerRun = ListBuffer[ParamMap]()
      for (_ <- 0 until getNumRuns) {
        // Generate the new parameters, stepping through estimators sequentially
        paramsPerRun += hyperParams.next()
      }
      val numModels = getModels.length

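      // For each CV fold, fit one model per sampled ParamMap on the training split
      // and score it on the held-out validation split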
      val metrics = splits.zipWithIndex.map { case ((training, validation), _) =>
        val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
        val validationDataset = sparkSession.createDataFrame(validation, schema).cache()

        val modelParams = paramsPerRun
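        // Train and evaluate every candidate for this fold asynchronously on the shared execution context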
        val foldMetricFutures = modelParams.zipWithIndex.map { case (paramMap, paramIndex) =>
          Future[Double] {
            val model = getModels(paramIndex % numModels).fit(trainingDataset, paramMap).asInstanceOf[Model[_]]
            val scoredDataset = model.transform(validationDataset, paramMap)
            val evaluator = new ComputeModelStatistics()
            evaluator.set(evaluator.evaluationMetric, getEvaluationMetric)
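            // SynapseML Trained* models are only logged here; plain SparkML models need the
            // evaluator pointed at their label/prediction columns explicitly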
            model match {
              case _: TrainedRegressorModel =>
                logDebug("Evaluating trained regressor model.")
              case _: TrainedClassifierModel =>
                logDebug("Evaluating trained classifier model.")
              case classificationModel: ClassificationModel[_, _] =>
                logDebug(s"Evaluating SparkML ${model.uid} classification model.")
                evaluator
                  .setLabelCol(classificationModel.getLabelCol)
                  .setScoredLabelsCol(classificationModel.getPredictionCol)
                  .setScoresCol(classificationModel.getRawPredictionCol)
                if (getEvaluationMetric == MetricConstants.AllSparkMetrics)
                  evaluator.setEvaluationMetric(MetricConstants.ClassificationMetricsName)
              case regressionModel: RegressionModel[_, _] =>
                logDebug(s"Evaluating SparkML ${model.uid} regression model.")
                evaluator
                  .setLabelCol(regressionModel.getLabelCol)
                  .setScoredLabelsCol(regressionModel.getPredictionCol)
                if (getEvaluationMetric == MetricConstants.AllSparkMetrics)
                  evaluator.setEvaluationMetric(MetricConstants.RegressionMetricsName)
            }
            val metrics = evaluator.transform(scoredDataset)
            val metric = metrics.select(evaluationMetricColumnName).first()(0).toString.toDouble
            logDebug(s"Got metric $metric for model trained with $paramMap.")
            metric
          }(executionContext)
        }
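        // Block until every candidate's validation metric for this fold is available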
        val foldMetrics = foldMetricFutures.toArray.map(awaitResult(_, Duration.Inf))

        trainingDataset.unpersist()
        validationDataset.unpersist()
        foldMetrics
      }.transpose.map(_.sum / $(numFolds)) // Calculate average metric over all splits

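      // Pick the run whose averaged metric is best under the metric's ordering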
      val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)(operator)
      // Compute best model fit on dataset
      val bestModel = getModels(bestIndex % numModels).fit(dataset, paramsPerRun(bestIndex)).asInstanceOf[Model[_]]
      new TuneHyperparametersModel(uid).setBestModel(bestModel).setBestMetric(bestMetric)
    })
  }
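
For orientation, here is a minimal usage sketch of how this fit is typically driven. The setter names below (setModels, setParamSpace, setEvaluationMetric, setNumFolds, setNumRuns, setSeed) are assumed from the Spark set*/get* param convention suggested by the getters above, and candidateModels, randomParamSpace, trainingData, and testData are placeholders for objects built elsewhere; none of this is taken from the excerpt itself.

  // Hypothetical usage sketch; setter names assumed from the Spark params convention.
  val tuner = new TuneHyperparameters()
    .setModels(candidateModels)              // estimators cycled through via paramIndex % numModels
    .setParamSpace(randomParamSpace)         // backs getParamSpace.paramMaps
    .setEvaluationMetric("accuracy")         // resolved by EvaluationUtils.getMetricWithOperator
    .setNumFolds(3)                          // k handed to MLUtils.kFold
    .setNumRuns(candidateModels.length * 2)  // number of ParamMaps drawn from the space
    .setSeed(1234L)

  val tuned: TuneHyperparametersModel = tuner.fit(trainingData)
  val predictions = tuned.transform(testData)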