in core/src/main/scala/com/microsoft/azure/synapse/ml/train/ComputeModelStatistics.scala [63:170]
/** Auxiliary no-argument constructor: delegates to the primary constructor with a uid
  * obtained from `Identifiable.randomUID("ComputeModelStatistics")`.
  */
def this() = this(Identifiable.randomUID("ComputeModelStatistics"))
/** The ROC curve evaluated for a binary classifier.
  *
  * Mutable field left at its default (uninitialized) value on construction.
  * NOTE(review): it is not assigned anywhere in the visible part of `transform`;
  * presumably a helper populates it while computing AUC - confirm against the rest of the class.
  */
var rocCurve: DataFrame = _
/** Metrics logger constructed with this stage's `uid`; created lazily on first access.
  * `transform` uses it to log the regression metrics it computes.
  */
lazy val metricsLogger = new MetricsLogger(uid)
/** Calculates the metrics for the given dataset and model.
  *
  * The schema information returned by `MetricUtils.getSchemaInfo` decides whether the
  * classification path or the regression path is taken; any other score kind is handed
  * to `throwOnInvalidScoringKind`.
  *
  * @param dataset the dataset to calculate the metrics for
  * @return DataFrame whose columns contain the calculated metrics
  */
override def transform(dataset: Dataset[_]): DataFrame = {
  logTransform[DataFrame]({
    val (model, labelName, kind) =
      MetricUtils.getSchemaInfo(
        dataset.schema,
        if (isDefined(labelCol)) Some(getLabelCol) else None,
        getEvaluationMetric)

    // The implicits are needed for the `toDF` calls on local sequences below.
    val session = dataset.sparkSession
    import session.implicits._

    // Classification path: starts from a one-row frame holding the evaluation type and
    // extends it according to the requested evaluation metric.
    def evaluateClassification(): DataFrame = {
      val baseDF: DataFrame =
        Seq(MetricConstants.ClassificationEvaluationType)
          .toDF(MetricConstants.EvaluationType)
      val predictedLabelCol =
        if (isDefined(scoredLabelsCol)) getScoredLabelsCol
        else SparkSchema.getSparkPredictionColumnName(dataset.schema, model)

      // Levels are present only when the label column is categorical.
      val levels = CategoricalUtilities.getLevels(dataset.schema, labelName)
      val hasLevels = levels.isDefined
      // Only forced on the branches guarded by `hasLevels`, so `levels.get` is safe here.
      lazy val levelIndex: Map[Any, Double] = getLevelsToIndexMap(levels.get)

      // Everything below is lazy so that each metric only computes what it needs.
      lazy val predictionAndLabels =
        if (hasLevels) getPredictionAndLabels(dataset, labelName, predictedLabelCol, levelIndex)
        else selectAndCastToRDD(dataset, predictedLabelCol, labelName)

      lazy val scoresAndLabels = {
        val rawScoresCol =
          if (isDefined(scoresCol)) getScoresCol
          else SparkSchema.getSparkRawPredictionColumnName(dataset.schema, model)
        if (rawScoresCol != null) {
          if (hasLevels) getScoresAndLabels(dataset, labelName, rawScoresCol, levelIndex)
          else getScalarScoresAndLabels(dataset, labelName, rawScoresCol)
        } else {
          // No scores column available: fall back to the predicted labels.
          predictionAndLabels
        }
      }

      lazy val (labels: Array[Double], confusionMatrix: Matrix) =
        createConfusionMatrix(predictionAndLabels)

      getEvaluationMetric match {
        case requested if requested == MetricConstants.AllSparkMetrics ||
          requested == MetricConstants.ClassificationMetricsName =>
          val withConfusionMatrix = addConfusionMatrixToResult(labels, confusionMatrix, baseDF)
          addAllClassificationMetrics(
            model, dataset, labelName, predictionAndLabels,
            confusionMatrix, scoresAndLabels, withConfusionMatrix)

        case requested if requested == MetricConstants.AccuracySparkMetric ||
          requested == MetricConstants.PrecisionSparkMetric ||
          requested == MetricConstants.RecallSparkMetric =>
          addSimpleMetric(requested, predictionAndLabels, baseDF)

        case MetricConstants.AucSparkMetric =>
          // Use the known levels when present; otherwise fall back to the confusion
          // matrix, which is only computed in that case.
          val classCount = levels.map(_.length).getOrElse(confusionMatrix.numRows)
          if (classCount > 2) {
            throw new Exception("Error: AUC is not available for multiclass case")
          }
          val auc: Double = getAUC(model, dataset, labelName, scoresAndLabels)
          baseDF.withColumn(MetricConstants.AucColumnName, lit(auc))

        case other =>
          throw new Exception(s"Error: $other is not a classification metric")
      }
    }

    // Regression path: always reports mse, rmse, r2 and mae, and logs them.
    def evaluateRegression(): DataFrame = {
      val predictionCol =
        if (isDefined(scoresCol)) getScoresCol
        else SparkSchema.getSparkPredictionColumnName(dataset.schema, model)
      val metrics = new RegressionMetrics(selectAndCastToRDD(dataset, predictionCol, labelName))
      val mse = metrics.meanSquaredError
      val rmse = metrics.rootMeanSquaredError
      val r2 = metrics.r2
      val mae = metrics.meanAbsoluteError
      metricsLogger.logRegressionMetrics(mse, rmse, r2, mae)
      val columnNames = Seq(
        MetricConstants.MseColumnName,
        MetricConstants.RmseColumnName,
        MetricConstants.R2ColumnName,
        MetricConstants.MaeColumnName)
      Seq((mse, rmse, r2, mae)).toDF(columnNames: _*)
    }

    if (kind == SchemaConstants.ClassificationKind) evaluateClassification()
    else if (kind == SchemaConstants.RegressionKind) evaluateRegression()
    else throwOnInvalidScoringKind(kind)
  })
}