in hype/manifolds/lorentz.py [0:0]
def expm(self, p, d_p, lr=None, out=None, normalize=False):
    """Exponential map for the hyperboloid.

    Moves every point ``x`` of ``p`` along its tangent vector ``v`` taken
    from ``d_p``::

        |v|      = sqrt(max(self.ldot(v, v), 0))
        t        = min(|v|, self.norm_clip)
        exp_x(v) = cosh(t) * x + sinh(t) * v / max(|v|, self.eps)

    Args:
        p: Tensor of points, one per row along dim 0.
        d_p: Either a dense tensor shaped like ``p``, or a sparse COO tensor
            with one sparse dimension whose indices select rows of ``p``.
            In the sparse case only the selected rows are updated, and
            duplicate indices are summed (via ``coalesce``) before the step.
        lr: If ``None``, ``d_p`` is used directly as the tangent step. If
            given, ``d_p`` (or its sparse values) is first converted in
            place -- its first coordinate is negated, then
            ``self.ldot(x, d_p) * x`` is added -- and scaled by ``-lr``.
            This applies to dense and sparse ``d_p`` alike.
        out: Tensor that receives the result. Defaults to ``p``, i.e. an
            in-place update of ``p``.
        normalize: If true, pass the new points through ``self.normalize``.

    Returns:
        ``out``.

    Raises:
        AssertionError: only when ``self.debug`` is set and the tangent
            norm contains NaNs or is not strictly positive.
    """
    if out is None:
        out = p

    def _check_tangent_norm(ldv):
        # Test for NaN first: a NaN also fails the `> 0` comparison and
        # would otherwise be reported with a misleading message.
        assert bool((ldv == ldv).all()), "Tangent norm includes NaNs"
        assert bool((ldv > 0).all()), "Tangent norm must be greater 0"

    def _to_scaled_step(x, d):
        # Convert gradient `d` at points `x` into a step in place:
        # negate the first coordinate, add <x, d>_L * x, scale by -lr.
        d.narrow(-1, 0, 1).mul_(-1)
        d.addcmul_(self.ldot(x, d, keepdim=True).expand_as(x), x)
        d.mul_(-lr)

    def _exp_step(x, d):
        # Apply the exponential map to points `x` with tangent vectors `d`.
        ldv = self.ldot(d, d, keepdim=True)
        if self.debug:
            _check_tangent_norm(ldv)
        nd = ldv.clamp_(min=0).sqrt_()
        # The cosh/sinh argument is capped at norm_clip (presumably to keep
        # cosh/sinh finite), while the divisor is floored at eps so that a
        # zero-length step does not divide by zero.
        t = th.clamp(nd, max=self.norm_clip)
        nd.clamp_(min=self.eps)
        new_x = (th.cosh(t) * x).addcdiv_(th.sinh(t) * d, nd)
        if normalize:
            new_x = self.normalize(new_x)
        return new_x

    if d_p.is_sparse:
        # Sum duplicate indices; otherwise index_copy_ below would write the
        # same row several times and keep only one partial update.
        # coalesce() returns the tensor itself when already coalesced.
        d_p = d_p.coalesce()
        # squeeze(0) keeps the index 1-D even when there is a single entry.
        ix, d_val = d_p._indices().squeeze(0), d_p._values()
        # This pulls `ix` out of the original embedding table, which could
        # be in a corrupted state. Normalize it to fix it back to the
        # surface of the hyperboloid.
        # TODO: we should only do the normalize if we know that we are
        # training with multiple threads, otherwise this is a bit wasteful
        p_val = self.normalize(p.index_select(0, ix))
        if lr is not None:
            _to_scaled_step(p_val, d_val)
        new_p = _exp_step(p_val, d_val)
        if out is not p:
            # Rows not selected by `ix` must carry over unchanged from `p`.
            out.copy_(p)
        out.index_copy_(0, ix, new_p)
    else:
        if lr is not None:
            _to_scaled_step(p, d_p)
        out.copy_(_exp_step(p, d_p))
    return out