in manifolds.py [0:0]
def tangent_orthonormal_basis(self, x, dF):
    """Return an orthonormal basis of the tangent space at each point of ``x``.

    Args:
        x: array of shape (n, d) with d in {2, 3}. Rows are assumed to be
            unit vectors (points on the circle S1 for d=2, on the sphere S2
            for d=3) -- TODO confirm callers always normalize.
        dF: array of shape (n, d), the potential's Riemannian derivative at
            ``x``. Only used for d=3, where the first basis vector is aligned
            with it. Unused for d=2.

    Returns:
        Array of shape (n, d, d - 1). ``E[i, :, k]`` is the k-th tangent basis
        vector at ``x[i]``.

    Raises:
        NotImplementedError: if ``d`` is neither 2 nor 3.
    """
    assert x.ndim == 2, f"expected x of shape (n, d), got ndim={x.ndim}"
    dim = x.shape[1]
    if dim == 2:
        # The tangent line of S1 at (x0, x1) is spanned by the 90-degree
        # rotation (-x1, x0).
        E = x[:, jnp.array([1, 0])] * jnp.array([-1., 1.])
        return E.reshape(*E.shape, 1)
    if dim == 3:
        # Remove any normal component so the basis stays orthonormal even if
        # dF has drifted numerically off the tangent plane.
        v = dF - jnp.sum(dF * x, axis=-1, keepdims=True) * x
        norm = jnp.linalg.norm(v, axis=-1, keepdims=True)
        ok = norm > jnp.finfo(norm.dtype).tiny

        # At critical points of the potential dF vanishes. Fall back to the
        # projection of the coordinate axis least aligned with x, which is
        # far from parallel to x for unit x, instead of dividing by zero.
        axis = jnp.argmin(jnp.abs(x), axis=-1)
        e = (jnp.arange(3) == axis[:, None]).astype(v.dtype)
        w = e - jnp.sum(e * x, axis=-1, keepdims=True) * x
        w = w / jnp.linalg.norm(w, axis=-1, keepdims=True)

        # The inner `where` keeps the unused branch finite, so no NaN
        # leaks into gradients.
        unit_v = jnp.where(ok, v / jnp.where(ok, norm, 1.0), w)
        # On S2 the remaining tangent direction is x x unit_v.
        return jnp.stack([unit_v, jnp.cross(x, unit_v)], axis=-1)
    raise NotImplementedError(
        f"tangent_orthonormal_basis supports d=2 (S1) and d=3 (S2) only, "
        f"got d={dim}"
    )