def check_p_conditions()

in causalml/inference/meta/utils.py [0:0]


def check_p_conditions(p, t_groups):
    """Validate the container type and value range of ``p``.

    NOTE(review): ``p`` presumably holds propensity scores and ``t_groups``
    the unique non-control treatment group names -- confirm against callers.

    Args:
        p (np.ndarray, pd.Series or dict): values to validate. When a dict,
            it is indexed by every entry of ``t_groups`` and each looked-up
            value is validated separately.
        t_groups (array-like): treatment group identifiers. When ``p`` is an
            array or Series, exactly one group is allowed.

    Raises:
        AssertionError: if ``p`` has an unsupported type, if ``p`` is an
            array/Series while ``t_groups`` does not contain exactly one
            entry, or if any value falls outside the open interval
            ``(eps, 1 - eps)`` where ``eps`` is float machine epsilon.
        KeyError: if ``p`` is a dict that lacks an entry of ``t_groups``.
    """
    # Checks raise AssertionError explicitly instead of using `assert`
    # statements: `assert` is compiled away under `python -O`, which would
    # silently disable this validation. The exception type and the messages
    # are kept so existing callers observe the same failures.
    eps = np.finfo(float).eps
    interval_msg = "The values of p should lie within the (0, 1) interval."

    if not isinstance(p, (np.ndarray, pd.Series, dict)):
        raise AssertionError("p must be an np.ndarray, pd.Series, or dict type")

    if isinstance(p, dict):
        # Look each group up lazily so that failures surface in iteration
        # order (a bad earlier group is reported before a missing later key).
        for t_name in t_groups:
            scores = p[t_name]
            if not ((eps < scores).all() and (scores < 1 - eps).all()):
                raise AssertionError(interval_msg)
        return

    # np.shape() equals `.shape` for arrays and also accepts plain sequences,
    # so a list/tuple of group names no longer fails with AttributeError.
    if np.shape(t_groups)[0] != 1:
        raise AssertionError(
            "If p is passed as an np.ndarray, there must be only 1 unique non-control group in the treatment vector."
        )
    if not ((eps < p).all() and (p < 1 - eps).all()):
        raise AssertionError(interval_msg)