in core/src/main/scala/com/microsoft/azure/synapse/ml/train/TrainClassifier.scala [53:189]
// No-argument constructor: hands a random UID generated with the "TrainClassifier" prefix
// to the constructor that takes the uid.
def this() = {
  this(Identifiable.randomUID("TrainClassifier"))
}
/** Doc for model to run.
*/
override def modelDoc: String = {
  // Description text for the model to run.
  "Classifier to run"
}
/** Specifies whether to reindex the given label column.
* See class documentation for how this parameter interacts with specified labels.
*
* @group param
*/
val reindexLabel: BooleanParam =
  new BooleanParam(this, "reindexLabel", "Re-index the label column")
// Re-indexing is on unless the caller explicitly sets the param.
setDefault(reindexLabel, true)
/** @group getParam */
def getReindexLabel: Boolean = getOrDefault(reindexLabel)
/** @group setParam */
def setReindexLabel(value: Boolean): this.type = set(reindexLabel -> value)
/** Specifies the labels metadata on the column.
* See class documentation for how this parameter interacts with reindex labels parameter.
*
* @group param
*/
val labels: StringArrayParam =
  new StringArrayParam(this, "labels", "Sorted label values on the labels column")
/** @group getParam */
def getLabels: Array[String] = getOrDefault(labels)
/** @group setParam */
def setLabels(value: Array[String]): this.type = set(labels -> value)
/** Optional parameter, specifies the name of the features column passed to the learner.
* Must have a unique name different from the input columns.
* By default, set to <uid>_features.
*
* @group param
*/
setDefault(featuresCol -> s"${uid}_features")
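/** Fits the classification model.
* Converts the label column, featurizes every other column into the features column,
* trains the configured learner on the featurized data and wraps the featurizer and the
* trained learner in a [[TrainedClassifierModel]].
*
* @param dataset The input dataset to train.
* @return The trained classification model.
* @throws Exception if a gradient boosted tree classifier is given more than two label levels.
*/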
override def fit(dataset: Dataset[_]): TrainedClassifierModel = {
  logFit({
    // Pass explicit label values through only when the `labels` param is defined.
    val labelValues = if (isDefined(labels)) Some(getLabels) else None
    // Convert label column to categorical on train, remove rows with missing labels
    val (convertedLabelDataset, levels) = convertLabel(dataset, getLabelCol, labelValues)
    val (oneHotEncodeCategoricals, modifyInputLayer, numFeatures) = getFeaturizeParams
    // More than two known label levels means the problem is multiclass.
    val isMulticlass = levels.exists(_.length > 2)

    // Step 1: pick the learner, adapting binary-only learners to the multiclass case.
    val selectedLearner: Estimator[_ <: PipelineStage] = getModel match {
      case logisticRegressionClassifier: LogisticRegression =>
        if (isMulticlass) {
          // Wrap logistic regression in one-vs-rest when there are more than two levels.
          new OneVsRest()
            .setClassifier(
              logisticRegressionClassifier
                .setLabelCol(getLabelCol)
                .setFeaturesCol(getFeaturesCol))
            .setLabelCol(getLabelCol)
            .setFeaturesCol(getFeaturesCol)
        } else {
          logisticRegressionClassifier
        }
      case gradientBoostedTreesClassifier: GBTClassifier =>
        if (isMulticlass) {
          throw new Exception("Multiclass Gradient Boosted Tree Classifier not supported yet")
        } else {
          gradientBoostedTreesClassifier
        }
      case default@defaultType if defaultType.isInstanceOf[Estimator[_ <: PipelineStage]] =>
        default
      case _ => throw new Exception("Unsupported learner type " + getModel.getClass.toString)
    }

    // Step 2: point predictors at the label and features columns used by this trainer.
    val classifier: Estimator[_ <: PipelineStage] = selectedLearner match {
      case predictor: Predictor[_, _, _] =>
        predictor
          .setLabelCol(getLabelCol)
          .setFeaturesCol(getFeaturesCol).asInstanceOf[Estimator[_ <: PipelineStage]]
      case default@defaultType if defaultType.isInstanceOf[Estimator[_ <: PipelineStage]] =>
        // assume label col and features col already set
        default
    }

    // Step 3: featurize every column except the label into the features column.
    // An explicitly configured (non-zero) feature count wins over the learner-specific default.
    val featuresToHashTo = if (getNumFeatures != 0) getNumFeatures else numFeatures
    val featureColumns = convertedLabelDataset.columns.filter(_ != getLabelCol)
    val featurizer = new Featurize()
      .setOutputCol(getFeaturesCol)
      .setInputCols(featureColumns)
      .setOneHotEncodeCategoricals(oneHotEncodeCategoricals)
      .setNumFeatures(featuresToHashTo)
    val featurizedModel = featurizer.fit(convertedLabelDataset)
    val processedData = featurizedModel.transform(convertedLabelDataset)

    // Step 4: train on the cached featurized data. The try/finally guarantees the cache is
    // released even when the input-layer adjustment or the learner's fit throws.
    processedData.cache()
    val fitModel: PipelineStage =
      try {
        // For neural network, need to modify input layer so it will automatically work during train
        if (modifyInputLayer) {
          val multilayerPerceptronClassifier = classifier.asInstanceOf[MultilayerPerceptronClassifier]
          // The input layer size is read from the first featurized row, so one must exist.
          val row = processedData.take(1).headOption.getOrElse(
            throw new IllegalArgumentException(
              "Cannot infer the neural network input layer size from an empty dataset"))
          val featuresVector = row.get(row.fieldIndex(getFeaturesCol))
          val vectorSize = featuresVector.asInstanceOf[linalg.Vector].size
          // Overwrite the first (input) layer size in place and write the layers back.
          val layers = multilayerPerceptronClassifier.getLayers
          layers(0) = vectorSize
          multilayerPerceptronClassifier.setLayers(layers)
        }
        // Train the learner
        classifier.fit(processedData)
      } finally {
        processedData.unpersist()
      }

    // Step 5: bundle the fitted featurizer and learner into a single pipeline model.
    // Note: The fit shouldn't do anything here
    val pipelineModel = new Pipeline()
      .setStages(Array[PipelineStage](featurizedModel, fitModel))
      .fit(convertedLabelDataset)
    val model = new TrainedClassifierModel()
      .setLabelCol(getLabelCol)
      .setModel(pipelineModel)
      .setFeaturesCol(getFeaturesCol)
    // Attach the label levels when label conversion produced them.
    levels.map(l => model.setLevels(l.toArray)).getOrElse(model)
  })
}