in core/src/main/scala/com/microsoft/azure/synapse/ml/exploratory/FeatureBalanceMeasure.scala [48:132]
/** Auxiliary no-arg constructor: delegates to the primary constructor with a UID obtained from
  * `Identifiable.randomUID("FeatureBalanceMeasure")`.
  */
def this() = this(Identifiable.randomUID("FeatureBalanceMeasure"))
/** Param holding the name of the output column for feature names.
  * `transform` fills this column with the name of the sensitive column each row was derived from.
  * Default: "FeatureName" (see the `setDefault` call in this class).
  */
val featureNameCol = new Param[String](
this,
"featureNameCol",
"Output column name for feature names."
)
/** Returns the value of [[featureNameCol]]. */
def getFeatureNameCol: String = $(featureNameCol)
/** Sets [[featureNameCol]] and returns this instance so that setter calls can be chained. */
def setFeatureNameCol(value: String): this.type = set(featureNameCol, value)
/** Param holding the name of the output column for the first feature value to compare.
  * Default: "ClassA" (see the `setDefault` call in this class).
  */
val classACol = new Param[String](
this,
"classACol",
"Output column name for the first feature value to compare."
)
/** Returns the value of [[classACol]]. */
def getClassACol: String = $(classACol)
/** Sets [[classACol]] and returns this instance so that setter calls can be chained. */
def setClassACol(value: String): this.type = set(classACol, value)
/** Param holding the name of the output column for the second feature value to compare.
  * Default: "ClassB" (see the `setDefault` call in this class).
  */
val classBCol = new Param[String](
this,
"classBCol",
"Output column name for the second feature value to compare."
)
/** Returns the value of [[classBCol]]. */
def getClassBCol: String = $(classBCol)
/** Sets [[classBCol]] and returns this instance so that setter calls can be chained. */
def setClassBCol(value: String): this.type = set(classBCol, value)
/** Sets the `labelCol` param (declared outside this section) and returns this instance so that
  * setter calls can be chained. `transform` reads this column and binarizes it.
  */
def setLabelCol(value: String): this.type = set(labelCol, value)
// Default names for the output columns: the three params declared above plus `outputCol`
// (declared outside this section).
setDefault(
featureNameCol -> "FeatureName",
classACol -> "ClassA",
classBCol -> "ClassB",
outputCol -> "FeatureBalanceMeasure"
)
/** Builds per-feature-value association metrics for every sensitive column and passes them to
  * `calculateParity`.
  *
  * Steps:
  *  1. Validate the input schema with `validateSchema`.
  *  1. Binarize the label column in place: rows whose label, cast to `Long`, is greater than 0
  *     get 1; every other row gets 0.
  *  1. Compute the dataset-wide row count and positive-label count.
  *  1. For each sensitive column, group by its values and compute the per-value row count and
  *     positive-label count, attach the dataset-wide totals, the sensitive column's name
  *     (in `getFeatureNameCol`) and the value itself (in "FeatureValue"), then union the
  *     per-column results.
  *  1. Add one column per entry of `AssociationMetrics(...).toColumnMap`.
  *
  * An empty input no longer fails with a `scala.MatchError`: the positive-label total is treated
  * as 0 when the aggregate sum is null (which happens only when there are no rows).
  *
  * @param dataset input dataset; must contain the label column and all sensitive columns
  * @return the DataFrame returned by `calculateParity` for the computed association metrics
  */
override def transform(dataset: Dataset[_]): DataFrame = {
  logTransform[DataFrame]({
    validateSchema(dataset.schema)

    val df = dataset
      // Convert label into binary
      // TODO (for v2): support regression scenarios
      .withColumn(getLabelCol, when(col(getLabelCol).cast(LongType) > lit(0L), lit(1L)).otherwise(lit(0L)))
      .cache()

    // try/finally guarantees the cached data is released even when an aggregation or the
    // plan construction below throws.
    try {
      val totals: Row =
        df.agg(count("*").cast(DoubleType), sum(getLabelCol).cast(DoubleType)).head
      val numRows = totals.getDouble(0)
      // sum(...) over zero rows is null; a `Double` type pattern would not match it.
      val numTrueLabels = if (totals.isNullAt(1)) 0d else totals.getDouble(1)

      // Intermediate column names are chosen so they cannot clash with input columns.
      val positiveFeatureCountCol = DatasetExtensions.findUnusedColumnName("positiveFeatureCount", df.schema)
      val featureCountCol = DatasetExtensions.findUnusedColumnName("featureCount", df.schema)
      val positiveCountCol = DatasetExtensions.findUnusedColumnName("positiveCount", df.schema)
      val rowCountCol = DatasetExtensions.findUnusedColumnName("rowCount", df.schema)
      val featureValueCol = "FeatureValue"

      // NOTE(review): `reduce` throws on an empty collection; presumably `validateSchema`
      // guarantees at least one sensitive column - confirm.
      val featureCounts = getSensitiveCols.map {
        sensitiveCol =>
          df
            .groupBy(sensitiveCol)
            .agg(
              sum(getLabelCol).cast(DoubleType).alias(positiveFeatureCountCol),
              count("*").cast(DoubleType).alias(featureCountCol)
            )
            .withColumn(positiveCountCol, lit(numTrueLabels))
            .withColumn(rowCountCol, lit(numRows))
            .withColumn(getFeatureNameCol, lit(sensitiveCol))
            .withColumn(featureValueCol, col(sensitiveCol))
      }.reduce(_ union _)

      val metrics =
        AssociationMetrics(positiveFeatureCountCol, featureCountCol, positiveCountCol, rowCountCol).toColumnMap

      val associationMetricsDf = metrics.foldLeft(featureCounts) {
        case (dfAcc, (metricName, metricFunc)) => dfAcc.withColumn(metricName, metricFunc)
      }

      // NOTE(review): everything built above is a lazy transformation, so unless
      // `calculateParity` triggers an action the cache is released before the returned
      // DataFrame is ever evaluated and mainly benefits the totals aggregation - confirm
      // whether that is intended.
      calculateParity(associationMetricsDf, featureValueCol)
    } finally {
      df.unpersist()
    }
  })
}