def kl_divergence()

in causalml/inference/tree/uplift.pyx [0:0]


def kl_divergence(pk: cython.float, qk: cython.float) -> cython.float:
    '''
    KL divergence between two binary (Bernoulli) distributions.

    Each distribution is described only by its probability of the positive
    class. In the general case the value computed is

    pk * log(pk / qk) + (1 - pk) * log((1 - pk) / (1 - qk))

    Special cases handled explicitly:

    * ``qk`` exactly 0 returns 0 without any further computation.
    * Any other ``qk`` is clipped into [1e-6, 1 - 1e-6] before it is used.
    * ``pk`` exactly 0 yields ``-log(1 - qk)`` and ``pk`` exactly 1 yields
      ``-log(qk)`` (with the clipped ``qk``), so no logarithm of ``pk`` or
      ``1 - pk`` is evaluated at zero.

    Args
    ----
    pk : float
        The probability of 1 in one distribution.
    qk : float
        The probability of 1 in the other distribution.

    Returns
    -------
    divergence : float
        The KL divergence.
    '''

    tol: cython.float = 1e-6
    q: cython.float
    q_rest: cython.float
    p_rest: cython.float
    divergence: cython.float

    # An exactly-zero reference probability short-circuits to zero.
    if qk == 0.:
        return 0.

    # Keep the reference probability strictly inside (0, 1) so that neither
    # q nor its complement is zero in the logarithms below.
    q = min(max(qk, tol), 1 - tol)
    q_rest = 1 - q

    if pk == 0.:
        # Only the "class 0" term remains.
        divergence = -log(q_rest)
    elif pk == 1.:
        # Only the "class 1" term remains.
        divergence = -log(q)
    else:
        p_rest = 1 - pk
        divergence = pk * log(pk / q) + p_rest * log(p_rest / q_rest)

    return divergence