in src/main/scala/com/amazon/deequ/analyzers/MutualInformation.scala [39:87]
/**
 * Turns a joint-frequency state into the "MutualInformation" metric for the two
 * analyzed columns.
 *
 * The state's frequency table holds one count per observed value pair. The
 * per-column (marginal) counts are obtained by summing those counts per value,
 * and the metric is the sum over all pairs of
 * `p(x, y) * ln(p(x, y) / (p(x) * p(y)))`, where every probability is a count
 * divided by the state's `numRows`.
 *
 * `columns` must contain exactly two names; the destructuring below fails with a
 * `MatchError` otherwise.
 *
 * @param state joint frequencies and row count, or `None` if no state is available
 * @return the metric holding the computed value, or the empty metric when `state` is
 *         `None` or the aggregated sum is null
 */
override def computeMetricFrom(state: Option[FrequenciesAndNumRows]): DoubleMetric = {
  // Instance label shared by every metric this method can return.
  val instance = columns.mkString(",")

  state match {
    case None =>
      metricFromEmpty(this, "MutualInformation", instance, Entity.Mutlicolumn)

    case Some(frequenciesAndNumRows) =>
      val numRows = frequenciesAndNumRows.numRows
      val Seq(firstColumn, secondColumn) = columns

      val firstMarginalCol = s"__deequ_f1_$firstColumn"
      val secondMarginalCol = s"__deequ_f2_$secondColumn"
      val contributionCol = s"__deequ_mi_${firstColumn}_$secondColumn"

      val joint = frequenciesAndNumRows.frequencies

      // Marginal counts for one column: sum the joint counts per distinct value.
      def marginalCounts(column: String, alias: String) =
        joint
          .select(column, COUNT_COL)
          .groupBy(column)
          .agg(sum(COUNT_COL).as(alias))

      val firstMarginal = marginalCounts(firstColumn, firstMarginalCol)
      val secondMarginal = marginalCounts(secondColumn, secondMarginalCol)

      // Contribution of a single value pair to the mutual information (natural log).
      val contribution = udf {
        (countX: Double, countY: Double, countXY: Double) =>
          val jointProbability = countXY / numRows
          jointProbability *
            math.log(jointProbability / ((countX / numRows) * (countY / numRows)))
      }

      // Attach both marginal counts to every joint row, then add up the contributions.
      val withMarginals = joint
        .join(firstMarginal, usingColumn = firstColumn)
        .join(secondMarginal, usingColumn = secondColumn)

      val summed = withMarginals
        .withColumn(
          contributionCol,
          contribution(col(firstMarginalCol), col(secondMarginalCol), col(COUNT_COL)))
        .agg(sum(contributionCol))

      val row = summed.head()

      // A null sum is reported as the empty metric rather than as a value.
      if (!row.isNullAt(0)) {
        metricFromValue(row.getDouble(0), "MutualInformation", instance, Entity.Mutlicolumn)
      } else {
        metricFromEmpty(this, "MutualInformation", instance, Entity.Mutlicolumn)
      }
  }
}