def this()

in core/src/main/scala/com/microsoft/azure/synapse/ml/train/TrainClassifier.scala [53:189]


  /** Creates a trainer whose uid is randomly generated with the "TrainClassifier" prefix. */
  def this() = {
    this(Identifiable.randomUID("TrainClassifier"))
  }

  /** Documentation text describing the learner (model) this trainer runs.
    *
    * @return a short, human-readable description of the model
    */
  override def modelDoc: String = {
    "Classifier to run"
  }

  /** Whether the label column is re-indexed before training.
    * See the class documentation for how this interacts with explicitly specified labels.
    * Defaults to `true`.
    *
    * @group param
    */
  val reindexLabel = new BooleanParam(this, "reindexLabel", "Re-index the label column")
  setDefault(reindexLabel, true)

  /** @return whether the label column will be re-indexed
    * @group getParam
    */
  def getReindexLabel: Boolean = this.$(reindexLabel)

  /** @param value whether to re-index the label column
    * @group setParam
    */
  def setReindexLabel(value: Boolean): this.type = set(reindexLabel -> value)

  /** Explicit, sorted label values for the labels column.
    * See the class documentation for how this interacts with the `reindexLabel` parameter.
    * This param has no default value. When it is unset, `fit` passes no label values to the
    * label conversion step, so check `isDefined(labels)` before calling `getLabels`.
    *
    * @group param
    */
  val labels = new StringArrayParam(this, "labels", "Sorted label values on the labels column")

  /** @return the explicitly configured label values
    * @group getParam
    */
  def getLabels: Array[String] = this.$(labels)

  /** @param value the sorted label values to use for the labels column
    * @group setParam
    */
  def setLabels(value: Array[String]): this.type = set(labels -> value)

  // Default for the inherited `featuresCol` param: "<uid>_features".
  // This is the column the featurized vector is written to and the learner reads from. It is
  // optional, but when overridden it must be a unique name different from every input column.
  setDefault(featuresCol -> s"${this.uid}_features")

  /** Fits the classification model.
    *
    * The steps are:
    *  1. Convert the label column to categorical values, dropping rows with missing labels.
    *  2. Pick the learner. A `LogisticRegression` is wrapped in `OneVsRest` when the label has
    *     more than two levels, and a multiclass `GBTClassifier` is rejected.
    *  3. Featurize every non-label column into `getFeaturesCol`.
    *  4. Train the learner on the featurized data.
    *  5. Combine the fitted featurizer and learner into a [[TrainedClassifierModel]].
    *
    * @param dataset The input dataset to train.
    * @return The trained classification model.
    * @throws Exception if the learner is a `GBTClassifier` and the label has more than two levels.
    * @throws IllegalArgumentException if the input layer of a neural network must be sized but no
    *                                  rows remain after label conversion.
    */
  override def fit(dataset: Dataset[_]): TrainedClassifierModel = {
    logFit({
      val labelValues = if (isDefined(labels)) Some(getLabels) else None
      // Convert label column to categorical on train, remove rows with missing labels
      val (convertedLabelDataset, levels) = convertLabel(dataset, getLabelCol, labelValues)
      val isMulticlass = levels.exists(_.length > 2)

      val (oneHotEncodeCategoricals, modifyInputLayer, numFeatures) = getFeaturizeParams

      val baseClassifier: Estimator[_ <: PipelineStage] = getModel match {
        case logisticRegressionClassifier: LogisticRegression if isMulticlass =>
          // Multiclass logistic regression is trained as one-vs-rest over binary classifiers.
          new OneVsRest()
            .setClassifier(
              logisticRegressionClassifier
                .setLabelCol(getLabelCol)
                .setFeaturesCol(getFeaturesCol))
            .setLabelCol(getLabelCol)
            .setFeaturesCol(getFeaturesCol)
        case _: GBTClassifier if isMulticlass =>
          throw new Exception("Multiclass Gradient Boosted Tree Classifier not supported yet")
        case estimator =>
          estimator
      }

      val classifier: Estimator[_ <: PipelineStage] = baseClassifier match {
        case predictor: Predictor[_, _, _] =>
          predictor
            .setLabelCol(getLabelCol)
            .setFeaturesCol(getFeaturesCol).asInstanceOf[Estimator[_ <: PipelineStage]]
        case other =>
          // Non-predictor estimators (e.g. the OneVsRest wrapper above) are assumed to have
          // their label and features columns already set.
          other
      }

      // An explicitly configured feature count takes precedence over the learner-derived one.
      val featuresToHashTo = if (getNumFeatures != 0) getNumFeatures else numFeatures

      val featureColumns = convertedLabelDataset.columns.filter(_ != getLabelCol)

      val featurizedModel = new Featurize()
        .setOutputCol(getFeaturesCol)
        .setInputCols(featureColumns)
        .setOneHotEncodeCategoricals(oneHotEncodeCategoricals)
        .setNumFeatures(featuresToHashTo)
        .fit(convertedLabelDataset)
      val processedData = featurizedModel.transform(convertedLabelDataset)

      processedData.cache()

      // Release the cached data even if sizing or training fails.
      val fitModel: PipelineStage =
        try {
          // For neural network, the input layer must match the featurized vector size.
          if (modifyInputLayer) {
            val multilayerPerceptronClassifier =
              classifier.asInstanceOf[MultilayerPerceptronClassifier]
            val firstRow = processedData.take(1).headOption.getOrElse(
              throw new IllegalArgumentException(
                "Cannot size the neural network input layer: no rows remain after label conversion"))
            val vectorSize = firstRow.getAs[linalg.Vector](getFeaturesCol).size
            // Update a copy so the caller's original layers array is not mutated in place.
            val layers = multilayerPerceptronClassifier.getLayers.clone()
            layers(0) = vectorSize
            multilayerPerceptronClassifier.setLayers(layers)
          }

          // Train the learner
          classifier.fit(processedData)
        } finally {
          processedData.unpersist()
        }

      // Both stages are already fitted, so this fit is expected only to assemble a PipelineModel.
      val pipelineModel = new Pipeline()
        .setStages(Array(featurizedModel, fitModel))
        .fit(convertedLabelDataset)
      val model = new TrainedClassifierModel()
        .setLabelCol(getLabelCol)
        .setModel(pipelineModel)
        .setFeaturesCol(getFeaturesCol)

      levels.map(l => model.setLevels(l.toArray)).getOrElse(model)
    })
  }