in densities.py [0:0]
def log_prob(self, x):
    """Return the log-probability of each point in ``x`` under the checkerboard.

    Points are mapped with ``utils.euclidean_to_spherical``; column 0 of the
    result is read as longitude and column 1 as latitude. The board is a
    4x4 grid:

    - Longitude columns have width ``s`` and cover ``[pi - 2s, pi + 2s)``.
    - Latitude rows have height ``s / 2`` and cover ``[pi/2 - s, pi/2 + s)``.
    - Cells alternate in a checkerboard pattern.

    Every interval is half-open ``[lo, hi)``.

    Args:
        x: Batch of points accepted by ``utils.euclidean_to_spherical``. Its
            output must be indexable as ``lonlat[:, 0]`` / ``lonlat[:, 1]``.

    Returns:
        1-D array with one entry per point (row of ``lonlat``):

        - On-board points get ``log(1 / K)``, where ``K`` is the number of
          on-board points in this batch.
        - Off-board points get ``-inf``. This also covers every point when
          ``K == 0``.
        - An empty batch yields an empty array.
    """
    # TODO: The normalization is over the batch, not over the sphere. The
    # result is therefore only a density up to a batch-dependent constant,
    # and it assumes x is uniformly distributed.
    # NOTE(review): if lat is the polar angle in [0, pi] and lon lies in
    # [0, 2*pi), the board's surface area is 4*s*sin(s). Using it would give
    # a batch-independent density. Confirm the convention of
    # utils.euclidean_to_spherical before switching.
    lonlat = jnp.asarray(utils.euclidean_to_spherical(x))
    s = jnp.pi / 2 - .2  # long side length
    lon = lonlat[:, 0]
    lat = lonlat[:, 1]

    # Vectorized membership test: one array op per interval instead of a
    # Python loop with per-point host sync, so this can also be traced/jitted.
    # Columns 1 and 3: [pi-2s, pi-s) and [pi, pi+s).
    odd_col = (((jnp.pi <= lon) & (lon < jnp.pi + s))
               | ((jnp.pi - 2 * s <= lon) & (lon < jnp.pi - s)))
    # Columns 2 and 4: the rest of [pi-2s, pi+2s).
    even_col = (~odd_col
                & (jnp.pi - 2 * s <= lon) & (lon < jnp.pi + 2 * s))
    # Rows filled in odd columns: [pi/2, pi/2+s/2) and [pi/2-s, pi/2-s/2).
    odd_row = (((jnp.pi / 2 <= lat) & (lat < jnp.pi / 2 + s / 2))
               | ((jnp.pi / 2 - s <= lat) & (lat < jnp.pi / 2 - s / 2)))
    # Rows filled in even columns: [pi/2+s/2, pi/2+s) and [pi/2-s/2, pi/2).
    even_row = (((jnp.pi / 2 + s / 2 <= lat) & (lat < jnp.pi / 2 + s))
                | ((jnp.pi / 2 - s / 2 <= lat) & (lat < jnp.pi / 2)))
    on_board = (odd_col & odd_row) | (even_col & even_row)

    probs = on_board.astype(float)
    total = jnp.sum(probs)
    # Avoid 0/0 when no point is on the board: probs stay 0, so log gives
    # -inf rather than NaN.
    probs = probs / jnp.where(total > 0, total, 1.0)
    return jnp.log(probs)