def log_prob()

in densities.py [0:0]


    def log_prob(self, x):
        """Log-probability of points under a spherical checkerboard density.

        With ``s = pi/2 - 0.2``, the board spans longitudes
        ``[pi - 2s, pi + 2s)`` and latitudes ``[pi/2 - s, pi/2 + s)``. It is
        split into a 4x4 grid of cells, each ``s`` wide in longitude and
        ``s/2`` tall in latitude. Alternating cells are "on". Every point in
        an "on" cell gets equal mass; every other point gets zero mass
        (log-prob ``-inf``).

        NOTE: the mass is normalized over the *batch* ``x``, not over the
        sphere, so each returned value depends on the other points in ``x``.
        This is only a reasonable density estimate when ``x`` is roughly
        uniformly distributed.

        Args:
            x: points in Euclidean coordinates. They are converted with
                ``utils.euclidean_to_spherical``, which is assumed to return
                an array whose rows are ``(lon, lat, ...)``. TODO confirm the
                angle conventions against that helper.

        Returns:
            Array with one log-probability per row of the converted input.
            If no point lies in an "on" cell, every entry is ``-inf``.
        """
        # TODO: Assumes x is uniformly distributed (see docstring NOTE).
        lonlat = utils.euclidean_to_spherical(x)
        lon = lonlat[:, 0]
        lat = lonlat[:, 1]
        s = jnp.pi / 2 - .2  # long side length of a cell

        def between(v, lo, hi):
            # Vectorized half-open interval test: lo <= v < hi.
            return (lo <= v) & (v < hi)

        # Board columns are longitude bands of width s and rows are latitude
        # bands of height s/2, both numbered from the low end.
        col0 = between(lon, jnp.pi - 2 * s, jnp.pi - s)
        col2 = between(lon, jnp.pi, jnp.pi + s)
        on_board_lon = between(lon, jnp.pi - 2 * s, jnp.pi + 2 * s)
        row0 = between(lat, jnp.pi / 2 - s, jnp.pi / 2 - s / 2)
        row1 = between(lat, jnp.pi / 2 - s / 2, jnp.pi / 2)
        row2 = between(lat, jnp.pi / 2, jnp.pi / 2 + s / 2)
        row3 = between(lat, jnp.pi / 2 + s / 2, jnp.pi / 2 + s)

        # Columns 0 and 2 are "on" in rows 0 and 2. The remaining board
        # columns (1 and 3) are "on" in rows 1 and 3. The result is a
        # checkerboard: cell (col, row) is on when col + row is even.
        inside = jnp.where(
            col0 | col2,
            row0 | row2,
            on_board_lon & (row1 | row3),
        )

        probs = jnp.asarray(inside, dtype=float)
        total = jnp.sum(probs)
        # When no point hits the board, divide by 1 instead of 0. Every
        # probability is then 0 and its log is -inf rather than nan.
        probs = probs / jnp.where(total > 0, total, 1.)
        return jnp.log(probs)