in core/src/main/scala/com/microsoft/azure/synapse/ml/train/TrainRegressor.scala [24:119]
/** Creates a TrainRegressor whose uid is a random identifier generated from the prefix "TrainRegressor". */
def this() = {
  this(Identifiable.randomUID("TrainRegressor"))
}
/** Documentation text for the model (learner) this trainer runs.
  *
  * Kept as a `def` rather than a `val` so the value is available even if it is read
  * while a parent trait is still being initialized.
  */
override def modelDoc: String = {
  "Regressor to run"
}
// Default name of the features column handed to the learner. The column must not clash with any
// input column, so by default it is derived from this stage's uid: "<uid>_features".
setDefault(featuresCol -> s"${uid}_features")
/** Fits the regression model.
  *
  * The steps are:
  *  1. Cast the label column to `DoubleType`. Integer, boolean, float, byte, long and short
  *     labels are cast; double labels are kept as-is.
  *  2. Drop rows whose label is null or NaN.
  *  3. Featurize every remaining non-label column into `getFeaturesCol`.
  *  4. Train the configured learner on the featurized data.
  *  5. Wrap the featurizer and the fitted learner in a single pipeline model.
  *
  * Tree-based learners (decision tree, GBT and random forest regressors) are featurized without
  * one-hot encoding of categoricals and with `FeaturizeUtilities.NumFeaturesTreeOrNNBased`
  * slots. All other learners use one-hot encoding and `FeaturizeUtilities.NumFeaturesDefault`.
  * A non-zero `numFeatures` param overrides either default.
  *
  * @param dataset The input dataset to train.
  * @return The trained regression model.
  * @throws Exception if the label column is a string or any other unsupported type, or if the
  *                   configured model is not an `Estimator`.
  */
override def fit(dataset: Dataset[_]): TrainedRegressorModel = {
  logFit({
    val labelColumn = getLabelCol

    // Featurization settings depend on the learner family; computed immutably in one match.
    val (oneHotEncodeCategoricals, numFeatures) = getModel match {
      case _: DecisionTreeRegressor | _: GBTRegressor | _: RandomForestRegressor =>
        (false, FeaturizeUtilities.NumFeaturesTreeOrNNBased)
      case _ =>
        (true, FeaturizeUtilities.NumFeaturesDefault)
    }

    val regressor = getModel match {
      case predictor: Predictor[_, _, _] =>
        predictor
          .setLabelCol(getLabelCol)
          .setFeaturesCol(getFeaturesCol).asInstanceOf[Estimator[_ <: PipelineStage]]
      case estimator if estimator.isInstanceOf[Estimator[_ <: PipelineStage]] =>
        // Non-Predictor estimators: assume label col and features col are already set.
        estimator
      case _ => throw new Exception("Unsupported learner type " + getModel.getClass.toString)
    }

    // An explicitly configured (non-zero) numFeatures wins over the learner-based default.
    val featuresToHashTo = if (getNumFeatures != 0) getNumFeatures else numFeatures

    // TODO: Handle DateType, TimestampType and DecimalType for label
    // Convert the label column during train to the correct type and drop missings
    val convertedLabelDataset = dataset.withColumn(labelColumn,
      dataset.schema(labelColumn).dataType match {
        case _: IntegerType |
             _: BooleanType |
             _: FloatType |
             _: ByteType |
             _: LongType |
             _: ShortType =>
          dataset(labelColumn).cast(DoubleType)
        case _: StringType =>
          throw new Exception("Invalid type: "
            + "Regressors are not able to train on a string label column: " + labelColumn)
        case _: DoubleType =>
          dataset(labelColumn)
        case default => throw new Exception("Unknown type: " + default.typeName +
          ", for label column: " + labelColumn)
      }
    ).na.drop(Seq(labelColumn))

    // Every column except the label is a candidate feature.
    val featureColumns = convertedLabelDataset.columns.filter(_ != labelColumn)
    val featurizer = new Featurize()
      .setOutputCol(getFeaturesCol)
      .setInputCols(featureColumns)
      .setOneHotEncodeCategoricals(oneHotEncodeCategoricals)
      .setNumFeatures(featuresToHashTo)
    val featurizedModel = featurizer.fit(convertedLabelDataset)
    val processedData = featurizedModel.transform(convertedLabelDataset)

    // Cache the featurized data for the learner, and always release it, even if training throws.
    processedData.cache()
    val fitModel =
      try {
        regressor.fit(processedData)
      } finally {
        processedData.unpersist()
      }

    // Both stages are already fitted transformers, so this fit only assembles the PipelineModel.
    val pipelineModel = new Pipeline().setStages(Array(featurizedModel, fitModel)).fit(convertedLabelDataset)
    new TrainedRegressorModel()
      .setLabelCol(labelColumn)
      .setModel(pipelineModel)
      .setFeaturesCol(getFeaturesCol)
  })
}