Excerpt: method `sample()` from densities.py


    def sample(self, key, n_samples):
        """Draw ``n_samples`` samples pooled evenly from every distribution in ``self.dists``.

        Each component distribution is asked for ``ceil(n_samples / len(self.dists))``
        samples. The pooled samples are concatenated along axis 0, shuffled, and
        truncated to the first ``n_samples`` rows.

        Args:
            key: JAX PRNG key. It is split into one subkey per component plus one
                extra subkey for the shuffle; ``key`` itself is not reused after
                splitting.
            n_samples: Number of samples to return (an int).

        Returns:
            The shuffled pooled samples, sliced to ``[:n_samples]`` along axis 0.
            (Assumes each component's ``sample(key, n)`` returns ``n`` rows along
            axis 0 — TODO confirm for all component types.)

        Raises:
            ValueError: If ``self.dists`` is empty.
        """
        n_dists = len(self.dists)
        if n_dists == 0:
            raise ValueError("cannot sample from a mixture with no component distributions")
        # One subkey per component plus one for the permutation: reusing the
        # already-split parent key for the shuffle would correlate the shuffle
        # with the component draws.
        keys = random.split(key, n_dists + 1)
        perm_key, component_keys = keys[0], keys[1:]
        # Exact integer ceiling division so the pool covers n_samples.
        per_dist = -(-n_samples // n_dists)
        samples = jnp.concatenate(
            [dist.sample(k, per_dist) for k, dist in zip(component_keys, self.dists)],
            axis=0,
        )
        samples = random.permutation(perm_key, samples)
        return samples[:n_samples]