def tangent_orthonormal_basis()

in manifolds.py [0:0]


    def tangent_orthonormal_basis(self, x, dF):
        """Return an orthonormal basis of the tangent space at each point of ``x``.

        Args:
            x: Array of shape ``(N, d)`` with ``d`` in ``{2, 3}``; each row is a
                point on the circle (``d == 2``) or the sphere (``d == 3``).
                Rows are assumed to be unit-norm (not checked here).
            dF: Riemannian derivative of the potential at ``x``, shape
                ``(N, 3)``, assumed to lie in the tangent space. Only used when
                ``d == 3``; ignored when ``d == 2``.

        Returns:
            Array ``E`` of shape ``(N, d, d - 1)``. The columns ``E[n, :, k]``
            are the tangent basis vectors at ``x[n]``. For ``d == 3`` the first
            column is ``dF`` normalized and the second is ``x`` cross the first
            column. Where ``dF`` is exactly zero, a deterministic fallback
            tangent direction is used for the first column.

        Raises:
            ValueError: If ``x`` is not 2-D.
            NotImplementedError: If ``x.shape[1]`` is neither 2 nor 3.
        """
        # Explicit raise rather than `assert`, which is stripped under `python -O`.
        if x.ndim != 2:
            raise ValueError(f"x must be 2-D (N, d), got shape {x.shape}")

        dim = x.shape[1]
        if dim == 2:
            # Rotate each point by +90 degrees: (x0, x1) -> (-x1, x0). This is
            # orthogonal to x. Add a trailing axis so E has shape (N, 2, 1).
            E = x[:, jnp.array([1, 0])] * jnp.array([-1., 1.])
            E = E.reshape(*E.shape, 1)
        elif dim == 3:
            # The potential's Riemannian derivative dF is on the tangent space,
            # so on S2 we normalize it and find the only remaining orthogonal
            # direction via the cross product with x.
            norm = jnp.linalg.norm(dF, axis=-1, keepdims=True)
            is_zero = norm == 0
            # Where dF vanishes (e.g. at a critical point of the potential),
            # normalizing would give NaN. Fall back to x cross the standard
            # basis axis along which |x| is smallest. That vector is tangent,
            # and for unit x its length is at least sqrt(2/3), so normalizing
            # it is safe.
            axis = jnp.eye(3)[jnp.argmin(jnp.abs(x), axis=-1)]
            fallback = jnp.cross(x, axis)
            fallback = fallback / jnp.linalg.norm(fallback, axis=-1, keepdims=True)
            # Double-where: divide by a safe denominator so the discarded
            # branch never evaluates 0/0. Otherwise NaN would leak into
            # gradients through jnp.where.
            safe_norm = jnp.where(is_zero, 1.0, norm)
            norm_v = jnp.where(is_zero, fallback, dF / safe_norm)
            E = jnp.dstack([norm_v, jnp.cross(x, norm_v)])
        else:
            raise NotImplementedError(
                f"tangent basis only implemented for d in (2, 3), got d={dim}"
            )

        return E