def preserve_statistic()

in statistickway.py [0:0]


def preserve_statistic(queries):
    """Build a jitted function that evaluates marginal statistics of a dataset ``D``.

    Each query is a sequence ``(c_1, ..., c_k)`` of column selections into ``D``.
    For one query, the statistic is the sum over rows ``r`` of the outer product
    ``D[r, c_1] x ... x D[r, c_k]``, flattened. The per-query results are
    concatenated in query order, and the whole vector is divided by the number
    of rows.

    Queries may have different lengths ``k``; each gets its own einsum spec.

    Args:
        queries: Iterable of queries. Each query is a sequence of 1 to 25 column
            selections. Assumes each selection makes ``D[:, c]`` 2-D (for example,
            an integer array of column indices). TODO confirm against callers.

    Returns:
        ``compute_statistic(D)``: a jitted function returning the 1-D statistic
        vector. It returns an empty vector when ``queries`` is empty.

    Raises:
        ValueError: if a query selects fewer than 1 or more than 25 column groups.
    """
    letters = string.ascii_lowercase
    # 'a' labels the row axis, which is summed out. The remaining letters label
    # one output axis per column selection in a query.
    row_idx, col_idxs = letters[0], letters[1:]

    def _einsum_spec(k):
        # Example: k=3 gives 'ab,ac,ad->bcd', the row-summed 3-way outer product.
        out_axes = col_idxs[:k]
        return ",".join(row_idx + axis for axis in out_axes) + "->" + out_axes

    # Resolve specs eagerly, outside the traced function, so invalid queries
    # fail at construction time. Previously only queries[0] was checked, and
    # every query was forced to use its spec.
    plan = []
    for q in queries:
        q = tuple(q)
        k = len(q)
        if not 1 <= k <= len(col_idxs):
            raise ValueError(
                f"each query must select between 1 and {len(col_idxs)} "
                f"column groups, got {k}"
            )
        plan.append((_einsum_spec(k), q))

    @jit
    def compute_statistic(D):
        if not plan:
            # Concatenating an empty list would fail; there are no statistics.
            return np.zeros((0,))
        parts = [
            np.einsum(spec, *[D[:, idx] for idx in q]).flatten()
            for spec, q in plan
        ]
        # Normalise by the row count so each entry is a per-row average.
        return np.concatenate(parts) / D.shape[0]

    return compute_statistic