def generate_random_noise()

in privacy_budget_tracking/zcdp_tracker.py [0:0]


    def generate_random_noise(key, shape, sampling_function=random.normal):
        """Sample noise of the requested shape from ``sampling_function``.

        Args:
            key: JAX PRNG key used for sampling. For shapes of rank >= 2 it is
                split into ``shape[0]`` subkeys, one per leading-axis slice.
            shape: Sequence of ints giving the requested output shape. Any
                sequence is accepted (tuple, list, ...); it is converted to a
                tuple because ``jit`` requires static arguments to be hashable.
            sampling_function: Callable taking ``(key, shape)`` positionally,
                e.g. ``jax.random.normal`` (the default) or
                ``jax.random.uniform``.

        Returns:
            Array of samples. For rank < 2 this is
            ``sampling_function(key, shape)``. For rank >= 2 it is the
            per-subkey samples of shape ``shape[1:]``, stacked along a new
            leading axis of length ``shape[0]``. That gives ``shape`` overall,
            assuming the sampler returns the shape it is asked for. These
            values differ from a single ``sampling_function(key, shape)`` call.

        See "JAX PRNG" in
        https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#JAX-PRNG
        """
        # jit needs hashable static args: a list shape (or its list slice
        # shape[1:]) would raise "Non-hashable static arguments".
        shape = tuple(shape)

        # shape (positional arg 1) is static because it fixes the output size.
        compiled_sampler = jit(sampling_function, static_argnums=1)

        if len(shape) < 2:
            # Scalars and vectors: a single draw with the original key.
            return compiled_sampler(key, shape)

        # One independent subkey per leading-axis slice. The sampler is
        # vectorized over the subkeys while the trailing shape is shared.
        subkeys = random.split(key, shape[0])
        return vmap(compiled_sampler, in_axes=(0, None), out_axes=0)(
            subkeys, shape[1:]
        )