in python/pyspark/pandas/correlation.py [0:0]
def compute(sdf: SparkDataFrame, groupKeys: List[str], method: str) -> SparkDataFrame:
    """
    Compute correlation per group, excluding NA/null values.

    Input PySpark Dataframe should contain column `CORRELATION_VALUE_1_COLUMN` and
    column `CORRELATION_VALUE_2_COLUMN`, as well as the group columns.
    The returned PySpark Dataframe will contain the correlation column
    `CORRELATION_CORR_OUTPUT_COLUMN` and the non-null count column
    `CORRELATION_COUNT_OUTPUT_COLUMN`, as well as the group columns.
    """
    assert len(groupKeys) > 0
    assert method in ["pearson", "spearman", "kendall"]
    # Null out BOTH value columns whenever either one is null, so that a row
    # contributes to the correlation only when both values are present. After
    # this select, null(VALUE_1) <=> null(VALUE_2), which later code relies on
    # when it checks only CORRELATION_VALUE_1_COLUMN for nullness.
    sdf = sdf.select(
        *[F.col(key) for key in groupKeys],
        *[
            # assign both columns nulls, if some of them are null
            F.when(
                F.isnull(CORRELATION_VALUE_1_COLUMN) | F.isnull(CORRELATION_VALUE_2_COLUMN),
                F.lit(None),
            )
            .otherwise(F.col(CORRELATION_VALUE_1_COLUMN))
            .alias(CORRELATION_VALUE_1_COLUMN),
            F.when(
                F.isnull(CORRELATION_VALUE_1_COLUMN) | F.isnull(CORRELATION_VALUE_2_COLUMN),
                F.lit(None),
            )
            .otherwise(F.col(CORRELATION_VALUE_2_COLUMN))
            .alias(CORRELATION_VALUE_2_COLUMN),
        ],
    )
    if method in ["pearson", "spearman"]:
        return _compute_pearson_spearman(sdf, groupKeys, method)
    else:
        return _compute_kendall(sdf, groupKeys)


def _values_to_avg_ranks(
    sdf: SparkDataFrame,
    window: "WindowSpec",
    valueColumn: str,
    rowNumberColumn: str,
    denseRankColumn: str,
    dropNulls: bool,
) -> SparkDataFrame:
    """
    Replace `valueColumn` with its per-group average rank (ties share the mean
    of their positions), as required by the Spearman correlation. For example:
      values:    3,   4,   5,   7,   7,   7,   9,   9,   10
      avg ranks: 1.0, 2.0, 3.0, 5.0, 5.0, 5.0, 7.5, 7.5, 9.0
    Null values keep a null rank. When `dropNulls` is True, null-valued rows
    are removed, except that the first row of each group is always kept so no
    group disappears from the output.
    """
    sdf = sdf.withColumn(
        rowNumberColumn,
        F.row_number().over(window.orderBy(F.asc_nulls_last(valueColumn))),
    )
    if dropNulls:
        # drop nulls but make sure each group contains at least one row
        sdf = sdf.where(~F.isnull(valueColumn) | (F.col(rowNumberColumn) == 1))
    return sdf.withColumn(
        denseRankColumn,
        F.dense_rank().over(window.orderBy(F.asc_nulls_last(valueColumn))),
    ).withColumn(
        valueColumn,
        # rangeBetween(0, 0) restricts the frame to rows with the SAME dense
        # rank (i.e. tied values), so avg(row_number) is the average position
        # of the tie group.
        F.when(F.isnull(valueColumn), F.lit(None)).otherwise(
            F.avg(rowNumberColumn).over(
                window.orderBy(F.asc(denseRankColumn)).rangeBetween(0, 0)
            )
        ),
    )


def _compute_pearson_spearman(
    sdf: SparkDataFrame, groupKeys: List[str], method: str
) -> SparkDataFrame:
    """
    Pearson/Spearman branch of `compute`. Spearman is Pearson applied to the
    per-group average ranks of the two value columns.
    """
    if method == "spearman":
        # convert values to avg ranks for spearman correlation
        ROW_NUMBER_COLUMN = verify_temp_column_name(
            sdf, "__correlation_spearman_row_number_temp_column__"
        )
        DENSE_RANK_COLUMN = verify_temp_column_name(
            sdf, "__correlation_spearman_dense_rank_temp_column__"
        )
        window = Window.partitionBy(groupKeys)
        # First column: also drop null rows (keeping one row per group);
        # second column: nulls were already pruned by the first pass.
        sdf = _values_to_avg_ranks(
            sdf, window, CORRELATION_VALUE_1_COLUMN,
            ROW_NUMBER_COLUMN, DENSE_RANK_COLUMN, dropNulls=True,
        )
        sdf = _values_to_avg_ranks(
            sdf, window, CORRELATION_VALUE_2_COLUMN,
            ROW_NUMBER_COLUMN, DENSE_RANK_COLUMN, dropNulls=False,
        )
    return sdf.groupby(groupKeys).agg(
        F.corr(CORRELATION_VALUE_1_COLUMN, CORRELATION_VALUE_2_COLUMN).alias(
            CORRELATION_CORR_OUTPUT_COLUMN
        ),
        # VALUE_1 is null iff VALUE_2 is null (see the select in `compute`),
        # so counting non-null VALUE_1 rows counts complete pairs.
        F.count(
            F.when(
                ~F.isnull(CORRELATION_VALUE_1_COLUMN),
                1,
            )
        ).alias(CORRELATION_COUNT_OUTPUT_COLUMN),
    )


def _compute_kendall(sdf: SparkDataFrame, groupKeys: List[str]) -> SparkDataFrame:
    """
    Kendall's tau-b branch of `compute`: self-join each group with itself and
    classify every unordered pair of rows as concordant (P), discordant (Q),
    tied only in x (T), or tied only in y (U).
    """
    ROW_NUMBER_1_2_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_row_number_1_2_temp_column__"
    )
    # Row numbers order each group deterministically; they are used both to
    # enumerate each unordered pair exactly once and to count non-null rows.
    sdf = sdf.withColumn(
        ROW_NUMBER_1_2_COLUMN,
        F.row_number().over(
            Window.partitionBy(groupKeys).orderBy(
                F.asc_nulls_last(CORRELATION_VALUE_1_COLUMN),
                F.asc_nulls_last(CORRELATION_VALUE_2_COLUMN),
            )
        ),
    )
    # drop nulls but make sure each group contains at least one row
    sdf = sdf.where(
        ~F.isnull(CORRELATION_VALUE_1_COLUMN) | (F.col(ROW_NUMBER_1_2_COLUMN) == 1)
    )
    CORRELATION_VALUE_X_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_value_x_temp_column__"
    )
    CORRELATION_VALUE_Y_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_value_y_temp_column__"
    )
    ROW_NUMBER_X_Y_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_row_number_x_y_temp_column__"
    )
    # Aliased copy of the group for the self-join: (1, 2) vs (x, y).
    sdf2 = sdf.select(
        *[F.col(key) for key in groupKeys],
        *[
            F.col(CORRELATION_VALUE_1_COLUMN).alias(CORRELATION_VALUE_X_COLUMN),
            F.col(CORRELATION_VALUE_2_COLUMN).alias(CORRELATION_VALUE_Y_COLUMN),
            F.col(ROW_NUMBER_1_2_COLUMN).alias(ROW_NUMBER_X_Y_COLUMN),
        ],
    )
    # Keep only pairs with row_number(1,2) <= row_number(x,y) so each
    # unordered pair appears once (the equality case is excluded later).
    sdf = sdf.join(sdf2, groupKeys, "inner").where(
        F.col(ROW_NUMBER_1_2_COLUMN) <= F.col(ROW_NUMBER_X_Y_COLUMN)
    )
    # compute P, Q, T, U in tau_b = (P - Q) / sqrt((P + Q + T) * (P + Q + U))
    # see https://github.com/scipy/scipy/blob/v1.9.1/scipy/stats/_stats_py.py#L5015-L5222
    CORRELATION_KENDALL_P_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_tau_b_p_temp_column__"
    )
    CORRELATION_KENDALL_Q_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_tau_b_q_temp_column__"
    )
    CORRELATION_KENDALL_T_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_tau_b_t_temp_column__"
    )
    CORRELATION_KENDALL_U_COLUMN = verify_temp_column_name(
        sdf, "__correlation_kendall_tau_b_u_temp_column__"
    )
    # A pair participates only if its values are non-null and the two rows
    # are distinct (strict inequality excludes the self-pair).
    pair_cond = ~F.isnull(CORRELATION_VALUE_1_COLUMN) & (
        F.col(ROW_NUMBER_1_2_COLUMN) < F.col(ROW_NUMBER_X_Y_COLUMN)
    )
    # concordant: both coordinates move in the same direction
    p_cond = (
        (F.col(CORRELATION_VALUE_1_COLUMN) < F.col(CORRELATION_VALUE_X_COLUMN))
        & (F.col(CORRELATION_VALUE_2_COLUMN) < F.col(CORRELATION_VALUE_Y_COLUMN))
    ) | (
        (F.col(CORRELATION_VALUE_1_COLUMN) > F.col(CORRELATION_VALUE_X_COLUMN))
        & (F.col(CORRELATION_VALUE_2_COLUMN) > F.col(CORRELATION_VALUE_Y_COLUMN))
    )
    # discordant: the coordinates move in opposite directions
    q_cond = (
        (F.col(CORRELATION_VALUE_1_COLUMN) < F.col(CORRELATION_VALUE_X_COLUMN))
        & (F.col(CORRELATION_VALUE_2_COLUMN) > F.col(CORRELATION_VALUE_Y_COLUMN))
    ) | (
        (F.col(CORRELATION_VALUE_1_COLUMN) > F.col(CORRELATION_VALUE_X_COLUMN))
        & (F.col(CORRELATION_VALUE_2_COLUMN) < F.col(CORRELATION_VALUE_Y_COLUMN))
    )
    # tied only in the first coordinate
    t_cond = (F.col(CORRELATION_VALUE_1_COLUMN) == F.col(CORRELATION_VALUE_X_COLUMN)) & (
        F.col(CORRELATION_VALUE_2_COLUMN) != F.col(CORRELATION_VALUE_Y_COLUMN)
    )
    # tied only in the second coordinate
    u_cond = (F.col(CORRELATION_VALUE_1_COLUMN) != F.col(CORRELATION_VALUE_X_COLUMN)) & (
        F.col(CORRELATION_VALUE_2_COLUMN) == F.col(CORRELATION_VALUE_Y_COLUMN)
    )
    sdf = (
        sdf.groupby(groupKeys)
        .agg(
            F.count(F.when(pair_cond & p_cond, 1)).alias(CORRELATION_KENDALL_P_COLUMN),
            F.count(F.when(pair_cond & q_cond, 1)).alias(CORRELATION_KENDALL_Q_COLUMN),
            F.count(F.when(pair_cond & t_cond, 1)).alias(CORRELATION_KENDALL_T_COLUMN),
            F.count(F.when(pair_cond & u_cond, 1)).alias(CORRELATION_KENDALL_U_COLUMN),
            # Row numbers were assigned with nulls ordered last, so the max
            # row number among non-null rows equals the non-null row count
            # (0 when the group's only kept row is the null placeholder).
            F.max(
                F.when(
                    ~F.isnull(CORRELATION_VALUE_1_COLUMN), F.col(ROW_NUMBER_X_Y_COLUMN)
                ).otherwise(F.lit(0))
            ).alias(CORRELATION_COUNT_OUTPUT_COLUMN),
        )
        .withColumn(
            CORRELATION_CORR_OUTPUT_COLUMN,
            (F.col(CORRELATION_KENDALL_P_COLUMN) - F.col(CORRELATION_KENDALL_Q_COLUMN))
            / F.sqrt(
                (
                    F.col(CORRELATION_KENDALL_P_COLUMN)
                    + F.col(CORRELATION_KENDALL_Q_COLUMN)
                    + F.col(CORRELATION_KENDALL_T_COLUMN)
                )
                * (
                    F.col(CORRELATION_KENDALL_P_COLUMN)
                    + F.col(CORRELATION_KENDALL_Q_COLUMN)
                    + F.col(CORRELATION_KENDALL_U_COLUMN)
                )
            ),
        )
    )
    return sdf.select(
        *[F.col(key) for key in groupKeys],
        *[CORRELATION_CORR_OUTPUT_COLUMN, CORRELATION_COUNT_OUTPUT_COLUMN],
    )