def compute()

in python/pyspark/pandas/correlation.py [0:0]


def compute(sdf: SparkDataFrame, groupKeys: List[str], method: str) -> SparkDataFrame:
    """
    Calculate the correlation of two value columns within each group, skipping NA/null values.

    Parameters
    ----------
    sdf : SparkDataFrame
        Must expose the columns `CORRELATION_VALUE_1_COLUMN` and
        `CORRELATION_VALUE_2_COLUMN` in addition to every column named in `groupKeys`.
    groupKeys : list of str
        Names of the grouping columns; at least one is required.
    method : str
        One of "pearson", "spearman" or "kendall".

    Returns
    -------
    SparkDataFrame
        The group columns together with `CORRELATION_CORR_OUTPUT_COLUMN` (the correlation)
        and `CORRELATION_COUNT_OUTPUT_COLUMN` (the non-null count).
    """
    assert len(groupKeys) > 0
    assert method in ["pearson", "spearman", "kendall"]

    group_cols = [F.col(key) for key in groupKeys]

    # A pair is only usable when both sides are present, so null out both values
    # as soon as either one of them is null.
    either_is_null = F.isnull(CORRELATION_VALUE_1_COLUMN) | F.isnull(CORRELATION_VALUE_2_COLUMN)
    cleaned_values = [
        F.when(either_is_null, F.lit(None)).otherwise(F.col(name)).alias(name)
        for name in (CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN)
    ]
    sdf = sdf.select(*group_cols, *cleaned_values)

    if method in ["pearson", "spearman"]:
        if method == "spearman":
            # spearman reuses the pearson aggregation below, applied to average ranks
            row_number_col = verify_temp_column_name(
                sdf, "__correlation_spearman_row_number_temp_column__"
            )
            dense_rank_col = verify_temp_column_name(
                sdf, "__correlation_spearman_dense_rank_temp_column__"
            )
            per_group = Window.partitionBy(groupKeys)
            # frame restricted to the rows sharing the current row's dense rank (i.e. ties)
            tied_rows = per_group.orderBy(F.asc(dense_rank_col)).rangeBetween(0, 0)

            # value -> avg rank, applied to each value column in turn
            # for example:
            # values:       3, 4, 5, 7, 7, 7, 9, 9, 10
            # avg ranks:    1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0
            value_columns = (CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN)
            for position, value_col in enumerate(value_columns):
                by_value = per_group.orderBy(F.asc_nulls_last(value_col))
                sdf = sdf.withColumn(row_number_col, F.row_number().over(by_value))
                if position == 0:
                    # drop nulls but make sure each group contains at least one row;
                    # only the first pass filters, since the select above nulled both
                    # values together
                    sdf = sdf.where(~F.isnull(value_col) | (F.col(row_number_col) == 1))
                sdf = sdf.withColumn(dense_rank_col, F.dense_rank().over(by_value))
                avg_rank = F.avg(row_number_col).over(tied_rows)
                sdf = sdf.withColumn(
                    value_col,
                    F.when(F.isnull(value_col), F.lit(None)).otherwise(avg_rank),
                )

        correlation = F.corr(CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN).alias(
            CORRELATION_CORR_OUTPUT_COLUMN
        )
        non_null_count = F.count(F.when(~F.isnull(CORRELATION_VALUE_1_COLUMN), 1)).alias(
            CORRELATION_COUNT_OUTPUT_COLUMN
        )
        return sdf.groupby(groupKeys).agg(correlation, non_null_count)

    # kendall correlation
    row_number_1_2_col = verify_temp_column_name(
        sdf, "__correlation_kendall_row_number_1_2_temp_column__"
    )
    by_both_values = Window.partitionBy(groupKeys).orderBy(
        F.asc_nulls_last(CORRELATION_VALUE_1_COLUMN),
        F.asc_nulls_last(CORRELATION_VALUE_2_COLUMN),
    )
    sdf = sdf.withColumn(row_number_1_2_col, F.row_number().over(by_both_values))

    # drop nulls but make sure each group contains at least one row
    sdf = sdf.where(~F.isnull(CORRELATION_VALUE_1_COLUMN) | (F.col(row_number_1_2_col) == 1))

    value_x_col = verify_temp_column_name(sdf, "__correlation_kendall_value_x_temp_column__")
    value_y_col = verify_temp_column_name(sdf, "__correlation_kendall_value_y_temp_column__")
    row_number_x_y_col = verify_temp_column_name(
        sdf, "__correlation_kendall_row_number_x_y_temp_column__"
    )

    # Renamed copy of the same rows: joining it back on the group keys pairs every
    # row of a group with every row of that group.
    other_side = sdf.select(
        *group_cols,
        F.col(CORRELATION_VALUE_1_COLUMN).alias(value_x_col),
        F.col(CORRELATION_VALUE_2_COLUMN).alias(value_y_col),
        F.col(row_number_1_2_col).alias(row_number_x_y_col),
    )

    # Keep each unordered pair once; self-pairs (equal row numbers) are retained here
    # and excluded from the P/Q/T/U counts by the strict comparison in `is_pair`.
    sdf = sdf.join(other_side, groupKeys, "inner").where(
        F.col(row_number_1_2_col) <= F.col(row_number_x_y_col)
    )

    # compute P, Q, T, U in tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U))
    # see https://github.com/scipy/scipy/blob/v1.9.1/scipy/stats/_stats_py.py#L5015-L5222
    p_col = verify_temp_column_name(sdf, "__correlation_kendall_tau_b_p_temp_column__")
    q_col = verify_temp_column_name(sdf, "__correlation_kendall_tau_b_q_temp_column__")
    t_col = verify_temp_column_name(sdf, "__correlation_kendall_tau_b_t_temp_column__")
    u_col = verify_temp_column_name(sdf, "__correlation_kendall_tau_b_u_temp_column__")

    left_1 = F.col(CORRELATION_VALUE_1_COLUMN)
    left_2 = F.col(CORRELATION_VALUE_2_COLUMN)
    right_x = F.col(value_x_col)
    right_y = F.col(value_y_col)
    has_value = ~F.isnull(CORRELATION_VALUE_1_COLUMN)

    # a genuine pair: two distinct rows, the left one carrying non-null values
    is_pair = has_value & (F.col(row_number_1_2_col) < F.col(row_number_x_y_col))

    # P: both values move in the same direction
    concordant = ((left_1 < right_x) & (left_2 < right_y)) | (
        (left_1 > right_x) & (left_2 > right_y)
    )
    # Q: the values move in opposite directions
    discordant = ((left_1 < right_x) & (left_2 > right_y)) | (
        (left_1 > right_x) & (left_2 < right_y)
    )
    # T: tied in the first value only
    tied_first_only = (left_1 == right_x) & (left_2 != right_y)
    # U: tied in the second value only
    tied_second_only = (left_1 != right_x) & (left_2 == right_y)

    counts = sdf.groupby(groupKeys).agg(
        F.count(F.when(is_pair & concordant, 1)).alias(p_col),
        F.count(F.when(is_pair & discordant, 1)).alias(q_col),
        F.count(F.when(is_pair & tied_first_only, 1)).alias(t_col),
        F.count(F.when(is_pair & tied_second_only, 1)).alias(u_col),
        # the largest paired row number among rows with a value, 0 if there is none
        F.max(F.when(has_value, F.col(row_number_x_y_col)).otherwise(F.lit(0))).alias(
            CORRELATION_COUNT_OUTPUT_COLUMN
        ),
    )

    p, q, t, u = F.col(p_col), F.col(q_col), F.col(t_col), F.col(u_col)
    tau_b = (p - q) / F.sqrt((p + q + t) * (p + q + u))

    return counts.withColumn(CORRELATION_CORR_OUTPUT_COLUMN, tau_b).select(
        *group_cols,
        CORRELATION_CORR_OUTPUT_COLUMN,
        CORRELATION_COUNT_OUTPUT_COLUMN,
    )