in densities.py [0:0]
def sample(self, key, n_samples):
    """Draw ``n_samples`` points by pooling equal-sized batches from ``self.dists``.

    Each component in ``self.dists`` contributes
    ``ceil(n_samples / len(self.dists))`` draws. The pooled draws are shuffled
    along axis 0 and truncated to ``n_samples``. As a result, the number of
    points taken from each component is approximately, not exactly, equal.

    Args:
        key: JAX PRNG key. It is split into independent subkeys (one per
            component plus one for the shuffle) and is never reused directly.
        n_samples: Number of samples to return.

    Returns:
        Array whose first axis has length ``n_samples``. This assumes every
        component's ``sample(key, n)`` returns arrays with compatible trailing
        shapes so they can be concatenated along axis 0 (TODO confirm).

    Raises:
        ValueError: If ``self.dists`` is empty.
    """
    n_dists = len(self.dists)
    if n_dists == 0:
        raise ValueError("cannot sample from a mixture with no components")
    # Split once into n_dists + 1 keys: one for the shuffle, one per component.
    # Using the parent `key` again after splitting it would correlate the
    # shuffle with the component draws.
    perm_key, *component_keys = random.split(key, n_dists + 1)
    n_per = int(np.ceil(n_samples / n_dists))
    samples = jnp.concatenate(
        [d.sample(k, n_per) for k, d in zip(component_keys, self.dists)],
        axis=0,
    )
    # Shuffle before truncating so the dropped surplus is not always taken
    # from the last component(s).
    samples = random.permutation(perm_key, samples)
    return samples[:n_samples]