in statistickway_threshold.py [0:0]
def preserve_statistic(queries):
    """Build a jitted function that evaluates a set of k-way statistics.

    For every query ``q`` the returned function computes::

        (n - einsum(spec, *[1 - D[:, idx] for idx in q]).flatten()) / n

    where ``n = D.shape[0]`` and ``spec`` contracts the shared row axis of
    all operands (e.g. ``'ab,ac,ad->bcd'`` for a 3-way query).  The results
    of all queries are concatenated into a single 1-D array.

    The einsum spec is derived from the length of each query, so queries of
    any order from 1 up to 25 (one lowercase letter per operand, one letter
    reserved for the row axis) are supported, and different queries may have
    different orders.

    Args:
        queries: Non-empty sequence of queries.  Each query is a sequence of
            column indexers; each indexer is used as ``D[:, idx]`` and must
            select a 2-D slice (e.g. an integer index array).

    Returns:
        A jitted callable ``compute_statistic(D)`` returning the concatenated
        statistics for all queries.

    Raises:
        ValueError: If ``queries`` is empty, or a query has no indexers or
            more indexers than there are einsum letters available.
    """
    # NOTE: the threshold is hard-coded to 1 here.
    # NOTE(review): presumably D holds 0/1 (one-hot) entries, in which case
    # each output value is the fraction of rows where at least one of the
    # selected columns is 1 -- confirm against callers.
    letters_idx = string.ascii_lowercase
    row_idx, col_idxs = letters_idx[0], letters_idx[1:]

    if len(queries) == 0:
        raise ValueError("queries must contain at least one query")

    # Build one einsum spec per query, e.g. 'ab,ac,ad->bcd' for a 3-way
    # query.  Done once here (outside the jitted function) because the specs
    # only depend on the query structure, not on the data.
    eins_strings = []
    for q in queries:
        order = len(q)
        if not 1 <= order <= len(col_idxs):
            raise ValueError(
                f"each query must have between 1 and {len(col_idxs)} "
                f"index groups, got {order}"
            )
        out_idxs = col_idxs[:order]
        operands = ",".join(row_idx + col for col in out_idxs)
        eins_strings.append(f"{operands}->{out_idxs}")

    @jit
    def compute_statistic(D):
        n_rows = D.shape[0]
        return (
            np.concatenate(
                [
                    n_rows
                    - np.einsum(
                        eins_string, *[(1 - D[:, idx_q]) for idx_q in q]
                    ).flatten()
                    for eins_string, q in zip(eins_strings, queries)
                ]
            )
            / n_rows
        )

    return compute_statistic