in src/smclarify/bias/metrics/pretraining.py [0:0]
def JS(label: pd.Series, sensitive_facet_index: pd.Series) -> float:
    r"""
    Jensen-Shannon Divergence (JS) between the label distributions of facets a and d.

    .. math::
        JS(Pa, Pd, P) = 0.5 [KL(Pa,P) + KL(Pd,P)] \geq 0

    where :math:`P = 0.5 (Pa + Pd)` is the mixture of the two facet distributions.
    The logarithm is the natural log (``np.log``), so the value is in nats.

    :param label: column of labels
    :param sensitive_facet_index: boolean column indicating sensitive group;
        rows where it is True form facet d, the remaining rows form facet a
    :return: Jensen-Shannon (JS) divergence metric
    :raises ValueError: if the aligned, nonzero label distributions come back empty
        (no label value shared by both facets)
    """
    require(sensitive_facet_index.dtype == bool, "sensitive_facet_index must be of type bool")

    # Split the labels by facet membership.
    facet_a_labels = label[~sensitive_facet_index]
    facet_d_labels = label[sensitive_facet_index]

    # pdfs_aligned_nonzero is defined elsewhere. From its name, it is presumed to
    # keep only the label values present in both facets, so the log ratios below
    # stay finite.
    Pa, Pd = pdfs_aligned_nonzero(facet_a_labels, facet_d_labels)
    mixture = (Pa + Pd) / 2

    if any(len(dist) == 0 for dist in (Pa, Pd, mixture)):
        raise ValueError("No instance of common facet found, dataset may be too small")

    kl_a = np.sum(Pa * np.log(Pa / mixture))
    kl_d = np.sum(Pd * np.log(Pd / mixture))
    return 0.5 * (kl_a + kl_d)