def this()

in core/src/main/scala/com/microsoft/azure/synapse/ml/exploratory/FeatureBalanceMeasure.scala [48:132]


  /** No-argument constructor: delegates to the uid-taking constructor with a uid
    * obtained from `Identifiable.randomUID("FeatureBalanceMeasure")`.
    */
  def this() = {
    this(Identifiable.randomUID("FeatureBalanceMeasure"))
  }

  /** Param naming the output column that holds feature names.
    * `transform` fills this column with the name of each sensitive column; defaults to "FeatureName".
    */
  val featureNameCol: Param[String] =
    new Param[String](this, "featureNameCol", "Output column name for feature names.")

  /** @return the value of [[featureNameCol]] (the explicitly set value, otherwise its default). */
  def getFeatureNameCol: String = getOrDefault(featureNameCol)

  /** Sets [[featureNameCol]].
    *
    * @param value output column name for feature names
    * @return this instance, to allow chaining
    */
  def setFeatureNameCol(value: String): this.type = set(featureNameCol -> value)

  /** Param naming the output column for the first feature value of a compared pair; defaults to "ClassA". */
  val classACol: Param[String] =
    new Param[String](this, "classACol", "Output column name for the first feature value to compare.")

  /** @return the value of [[classACol]] (the explicitly set value, otherwise its default). */
  def getClassACol: String = getOrDefault(classACol)

  /** Sets [[classACol]].
    *
    * @param value output column name for the first feature value to compare
    * @return this instance, to allow chaining
    */
  def setClassACol(value: String): this.type = set(classACol -> value)

  /** Param naming the output column for the second feature value of a compared pair; defaults to "ClassB". */
  val classBCol: Param[String] =
    new Param[String](this, "classBCol", "Output column name for the second feature value to compare.")

  /** @return the value of [[classBCol]] (the explicitly set value, otherwise its default). */
  def getClassBCol: String = getOrDefault(classBCol)

  /** Sets [[classBCol]].
    *
    * @param value output column name for the second feature value to compare
    * @return this instance, to allow chaining
    */
  def setClassBCol(value: String): this.type = set(classBCol -> value)

  /** Sets `labelCol`, the column that `transform` reads and binarizes as the label.
    *
    * @param value name of the label column
    * @return this instance, to allow chaining
    */
  def setLabelCol(value: String): this.type = set(labelCol -> value)

  // Default names for the output columns produced by this transformer.
  setDefault(featureNameCol, "FeatureName")
  setDefault(classACol, "ClassA")
  setDefault(classBCol, "ClassB")
  setDefault(outputCol, "FeatureBalanceMeasure")

  /** Builds, for every value of every sensitive column, the counts and association metrics
    * that `calculateParity` turns into the final result.
    *
    * Steps:
    *  1. Validate the input schema and binarize the label column
    *     (1 when the label cast to long is > 0, otherwise 0, which includes null labels).
    *  2. Compute the total row count and the total number of positive labels.
    *  3. For each sensitive column, group by its values and compute the per-value row count and
    *     positive count, then union the per-column results.
    *  4. Add one column per entry of `AssociationMetrics(...).toColumnMap` and hand the result to
    *     `calculateParity`.
    *
    * @param dataset input data containing the label column and the sensitive columns
    * @return the DataFrame produced by `calculateParity`
    */
  override def transform(dataset: Dataset[_]): DataFrame = {
    logTransform[DataFrame]({
      validateSchema(dataset.schema)

      val df = dataset
        // Convert label into binary
        // TODO (for v2): support regression scenarios
        .withColumn(getLabelCol, when(col(getLabelCol).cast(LongType) > lit(0L), lit(1L)).otherwise(lit(0L)))
        .cache()

      // try/finally guarantees the cached data is released even when one of the steps below throws.
      try {
        val totals = df.agg(count("*").cast(DoubleType), sum(getLabelCol).cast(DoubleType)).head()
        val numRows = totals.getDouble(0)
        // sum over zero rows is SQL NULL; treat it as 0 positives instead of failing on an empty dataset.
        val numTrueLabels = if (totals.isNullAt(1)) 0.0 else totals.getDouble(1)

        // Intermediate column names are chosen so they cannot collide with columns of the input.
        val positiveFeatureCountCol = DatasetExtensions.findUnusedColumnName("positiveFeatureCount", df.schema)
        val featureCountCol = DatasetExtensions.findUnusedColumnName("featureCount", df.schema)
        val positiveCountCol = DatasetExtensions.findUnusedColumnName("positiveCount", df.schema)
        val rowCountCol = DatasetExtensions.findUnusedColumnName("rowCount", df.schema)
        val featureValueCol = "FeatureValue"

        // NOTE(review): reduce throws on an empty list of sensitive columns; presumably
        // validateSchema or param validation rules that out - TODO confirm.
        val featureCounts = getSensitiveCols.map {
          sensitiveCol =>
            df
              .groupBy(sensitiveCol)
              .agg(
                sum(getLabelCol).cast(DoubleType).alias(positiveFeatureCountCol),
                count("*").cast(DoubleType).alias(featureCountCol)
              )
              .withColumn(positiveCountCol, lit(numTrueLabels))
              .withColumn(rowCountCol, lit(numRows))
              .withColumn(getFeatureNameCol, lit(sensitiveCol))
              .withColumn(featureValueCol, col(sensitiveCol))
        }.reduce(_ union _)

        val metrics =
          AssociationMetrics(positiveFeatureCountCol, featureCountCol, positiveCountCol, rowCountCol).toColumnMap

        val associationMetricsDf = metrics.foldLeft(featureCounts) {
          case (dfAcc, (metricName, metricFunc)) => dfAcc.withColumn(metricName, metricFunc)
        }

        // NOTE(review): the result is presumably a lazy plan, so the cache mainly benefits the
        // totals aggregation above; confirm whether calculateParity runs any action.
        calculateParity(associationMetricsDf, featureValueCol)
      } finally {
        df.unpersist()
      }
    })
  }