def sample()

in densities.py [0:0]


    def sample(self, key, n_samples):
        """Draw ``n_samples`` points from a mixture of eight uniform rectangles.

        A base rectangle ``[pi, pi + s] x [pi/2, pi/2 + s/2]`` (with
        ``s = pi/2 - 0.2``) is sampled uniformly in two angular coordinates.
        Each point is then shifted by one of eight fixed offsets, chosen
        uniformly at random. The shifted angles are passed to
        ``utils.spherical_to_euclidean``.

        Args:
            key: JAX PRNG key. It is split into three subkeys: one for each
                coordinate and one for the offset index.
            n_samples: Number of points to draw.

        Returns:
            The result of ``utils.spherical_to_euclidean`` applied to an
            ``(n_samples, 2)`` array of angles.
            NOTE(review): presumably ``(n_samples, 3)`` points on the unit
            sphere; confirm against ``utils``.
        """
        s = jnp.pi / 2. - .2  # long side length; the short side is s / 2
        offsets = jnp.array([
            (0, 0), (s, s / 2), (s, -s / 2), (0, -s), (-s, s / 2),
            (-s, -s / 2), (-2 * s, 0), (-2 * s, -s)])

        # Each random draw uses its own subkey so that the two coordinates
        # and the offset choice are statistically independent.
        k1, k2, k3 = jax.random.split(key, 3)

        # (x1, x2) ~ uniform([pi, pi + s] x [pi/2, pi/2 + s/2])
        x1 = jax.random.uniform(k1, [n_samples]) * s + jnp.pi
        x2 = jax.random.uniform(k2, [n_samples]) * s / 2. + jnp.pi / 2.
        samples = jnp.stack([x1, x2], axis=1)

        # Pick one offset per sample. maxval is exclusive, so every one of
        # the len(offsets) rows is reachable with equal probability.
        idx = jax.random.randint(
            k3, [n_samples], minval=0, maxval=len(offsets))
        samples = samples + offsets[idx]

        return utils.spherical_to_euclidean(samples)