in statistickway.py [0:0]
def preserve_statistic(queries):
    """Build a jitted function that evaluates marginal statistics of a dataset.

    Each query ``q`` is a sequence of column selections ``idx`` used as
    ``D[:, idx]``; because the einsum operands are subscripted with two
    letters, every selection must yield a 2-D ``(rows, selected_columns)``
    slice. For a query of length ``k`` the statistic is the sum over rows of
    the outer product of the ``k`` selected slices (e.g. for ``k == 3`` the
    einsum subscripts are ``'ab,ac,ad->bcd'``), flattened. The per-query
    results are concatenated in query order and divided by ``D.shape[0]``.

    Queries may have different lengths; each one gets its own einsum
    subscript string, built once here rather than on every trace.

    Args:
        queries: non-empty sequence of queries as described above.
            NOTE(review): presumably each ``idx`` picks the one-hot columns
            of a single attribute, making the result a normalized k-way
            marginal -- confirm against callers.

    Returns:
        A ``jit``-compiled function ``compute_statistic(D)`` returning the
        concatenated, row-normalized statistics as a 1-D array.

    Raises:
        ValueError: if ``queries`` is empty, or if any query is empty or has
            more selections than the 25 letters available for einsum column
            subscripts.
    """
    letters_idx = string.ascii_lowercase
    # 'a' indexes rows (summed out by einsum); 'b'..'z' index the columns of
    # each operand, which bounds a single query to 25 selections.
    row_idx, col_idxs = letters_idx[0], letters_idx[1:]

    # len() rather than truthiness so array-like `queries` are accepted too.
    if len(queries) == 0:
        raise ValueError("queries must be non-empty")

    # Pair every query with its own subscript string so queries of different
    # orders can be mixed in one workload.
    plans = []
    for q in queries:
        k = len(q)
        if not 1 <= k <= len(col_idxs):
            raise ValueError(
                f"each query must have between 1 and {len(col_idxs)} "
                f"selections, got {k}"
            )
        cols = col_idxs[:k]
        # e.g. k == 3 -> 'ab,ac,ad->bcd': outer product of the k slices,
        # summed over the shared row axis 'a'.
        eins_string = ",".join(row_idx + c for c in cols) + "->" + cols
        plans.append((eins_string, q))

    @jit
    def compute_statistic(D):
        stats = [
            np.einsum(eins_string, *[D[:, idx_q] for idx_q in q]).flatten()
            for eins_string, q in plans
        ]
        return np.concatenate(stats) / D.shape[0]

    return compute_statistic