def preserve_statistic()

in statistickway_threshold.py [0:0]


def preserve_statistic(queries):
    """Build a jitted function computing thresholded k-way statistics of ``D``.

    The threshold is hard-coded to 1. For each query, the statistic is
    ``(n_rows - sum_rows(prod_j (1 - D[:, idx_j]))) / n_rows``. That is the
    fraction of rows in which at least one of the selected columns is
    active, assuming ``D`` holds 0/1 values (NOTE(review): confirm against
    callers).

    The einsum subscript string is built from each query's own length. So
    queries of any order ``1 <= k <= 25`` are supported, and different
    queries may have different orders. For k = 3 the subscripts are
    ``'ab,ac,ad->bcd'``, which is equivalent to the previous hard-coded
    ``'ij,ik,il->jkl'``.

    Args:
        queries: Sequence of queries. Each query is a sequence of column
            selectors. Each selector ``idx_q`` is used as ``D[:, idx_q]`` and
            must yield a 2-D slice (presumably a list/array of column
            indices).

    Returns:
        A ``jit``-compiled function that maps a 2-D array ``D`` to a 1-D
        array. The output is the flattened k-dimensional table of each query,
        concatenated in query order and divided by ``D.shape[0]``.

    Raises:
        AssertionError: If a query is empty or has more selectors than there
            are available einsum subscript letters (25).
    """
    letters_idx = string.ascii_lowercase
    # 'a' indexes the rows of D; the remaining letters index the columns of
    # each selector slice, one letter per selector.
    row_idx, col_idxs = letters_idx[0], letters_idx[1:]

    # Precompute (subscripts, query) pairs once, outside the jitted function.
    terms = []
    for q in queries:
        k = len(q)
        assert 1 <= k <= len(col_idxs), (
            f"query order {k} not in [1, {len(col_idxs)}]"
        )
        # Summing over the shared row letter leaves the k-dimensional
        # outer-product table, e.g. 'ab,ac,ad->bcd' for k = 3.
        operands = ",".join(row_idx + c for c in col_idxs[:k])
        terms.append((f"{operands}->{col_idxs[:k]}", q))

    @jit
    def compute_statistic(D):
        n_rows = D.shape[0]
        # For 0/1 data, the product of complements is 1 exactly when none of
        # the selected columns is 1. Subtracting that count from n_rows gives
        # the number of rows where at least one is 1 (threshold = 1).
        return (
            np.concatenate(
                [
                    n_rows
                    - np.einsum(
                        eins_string, *[(1 - D[:, idx_q]) for idx_q in q]
                    ).flatten()
                    for eins_string, q in terms
                ]
            )
            / n_rows
        )

    return compute_statistic