in densities.py [0:0]
def sample(self, key, n_samples):
    """Draw samples from a mixture of eight translated uniform rectangles.

    A base point is drawn uniformly from the rectangle
    ``[pi, pi + s) x [pi/2, pi/2 + s/2)`` with ``s = pi/2 - 0.2``, then
    shifted by one of eight fixed offsets chosen uniformly at random.
    The resulting 2-D coordinates are mapped through
    ``utils.spherical_to_euclidean``.

    Args:
        key: JAX PRNG key; split internally into three independent subkeys.
        n_samples: Number of samples to draw.

    Returns:
        The result of ``utils.spherical_to_euclidean`` applied to an
        ``(n_samples, 2)`` array of coordinates (presumably spherical
        angles -- confirm against ``utils``).
    """
    s = jnp.pi / 2. - .2  # long side length of each rectangle
    offsets = jnp.array([
        (0, 0), (s, s / 2), (s, -s / 2), (0, -s), (-s, s / 2),
        (-s, -s / 2), (-2 * s, 0), (-2 * s, -s),
    ])
    # One subkey per independent draw: reusing a key would correlate draws.
    k1, k2, k3 = jax.random.split(key, 3)
    # (x1, x2) ~ uniform([pi, pi + s) x [pi/2, pi/2 + s/2))
    x1 = jax.random.uniform(k1, [n_samples]) * s + jnp.pi
    x2 = jax.random.uniform(k2, [n_samples]) * s / 2. + jnp.pi / 2.
    samples = jnp.stack([x1, x2], axis=1)
    # Choose one of the eight offsets per sample, uniformly.
    idx = jax.random.randint(k3, [n_samples], minval=0, maxval=len(offsets))
    samples = samples + offsets[idx]
    return utils.spherical_to_euclidean(samples)