in privacy_budget_tracking/zcdp_tracker.py [0:0]
def generate_random_noise(key, shape, sampling_function=random.normal):
    """Sample a noise array of the requested ``shape`` using a JAX PRNG key.

    ``sampling_function`` is wrapped with ``jit`` (with ``shape`` marked as a
    static argument) and called as ``sampling_function(key, shape)``; it MUST
    accept exactly a PRNG key and a shape. By default the normal distribution
    (``random.normal``) is used.

    For shapes with fewer than two dimensions the sampler is called once with
    ``key`` directly. Otherwise ``key`` is split into ``shape[0]`` subkeys and
    the sampler is vmapped over them, each subkey producing one slice of shape
    ``shape[1:]``. Background on generating random numbers with subkeys:
    https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG

    Args:
        key: JAX PRNG key.
        shape: Output shape. It is a static ``jit`` argument, so it must be
            hashable (e.g. a tuple rather than a list).
        sampling_function: Callable ``(key, shape) -> array``.

    Returns:
        The sampled array -- presumably of shape ``shape``, provided the
        sampler returns an array of exactly the shape it is given.
    """
    sampler = jit(sampling_function, static_argnums=1)

    if len(shape) >= 2:
        # One subkey per leading row; every row is drawn independently with
        # the remaining trailing shape, then stacked along axis 0 by vmap.
        num_rows, row_shape = shape[0], shape[1:]
        row_keys = random.split(key, num_rows)
        per_row_sampler = vmap(sampler, in_axes=(0, None), out_axes=0)
        return per_row_sampler(row_keys, row_shape)

    # Scalars and 1-D shapes: a single draw with the original key suffices.
    return sampler(key, shape)