# KL()
#
# in src/smclarify/bias/metrics/pretraining.py [0:0]


def KL(label: pd.Series, sensitive_facet_index: pd.Series) -> float:
    r"""
    Kullback-Leibler Divergence (KL)

    .. math::
        KL(Pa, Pd) = \sum_{x}{Pa(x) \ log \frac{Pa(x)}{Pd(x)}}

    :param label: column of labels
    :param sensitive_facet_index: boolean column indicating sensitive group
    :return: Kullback and Leibler (KL) divergence metric
    """
    require(sensitive_facet_index.dtype == bool, "sensitive_facet_index must be of type bool")
    # Partition the labels: Pa comes from the non-sensitive (advantaged) rows,
    # Pd from the sensitive (disadvantaged) rows.
    advantaged_labels = label[~sensitive_facet_index]
    disadvantaged_labels = label[sensitive_facet_index]
    Pa, Pd = pdfs_aligned_nonzero(advantaged_labels, disadvantaged_labels)
    if min(len(Pa), len(Pd)) == 0:
        raise ValueError("No instance of common facet found, dataset may be too small")
    # KL(Pa || Pd) = sum_x Pa(x) * log(Pa(x) / Pd(x)); the aligned, nonzero
    # PDFs guarantee the log argument is well-defined.
    divergence = np.sum(Pa * np.log(Pa / Pd))
    return divergence