in src/rime/util/score_array.py [0:0]
def eval(self, device=None):
    """Materialize the dense score matrix ``act(ind_logits @ col_logits.T)``.

    The activation is selected by ``self.act``. The product and activation
    are computed with NumPy when ``device`` is None, otherwise with torch on
    the requested device.

    Args:
        device: ``None`` for the NumPy path; otherwise any value accepted by
            ``torch.as_tensor(..., device=device)`` (e.g. ``'cpu'``).

    Returns:
        ``np.ndarray`` (NumPy path) or ``torch.Tensor`` (torch path) holding
        the activated low-rank product.

    Raises:
        AssertionError: if the low-rank product contains NaN.
        NotImplementedError: if ``self.act`` is not one of ``'exp'``,
            ``'softplus'``, ``'sigmoid'`` or ``'_nnmf'``.
    """
    if device is None:
        z = self.ind_logits @ self.col_logits.T
        assert not np.isnan(z).any(), "low rank score must be valid"
        if self.act == 'exp':
            return np.exp(z)
        elif self.act == 'softplus':
            # Stable log(1 + exp(z)): no overflow for large z, and very
            # negative z keeps its tiny positive value instead of rounding
            # to exactly 0 -- agreeing with F.softplus on the torch path.
            return np.logaddexp(0., z)
        elif self.act == 'sigmoid':
            # exp(-z) overflows to inf for very negative z; 1 / (1 + inf)
            # is already the correct limit 0, so only silence the warning.
            with np.errstate(over='ignore'):
                return 1. / (1 + np.exp(-z))
        elif self.act == '_nnmf':
            # Raw (un-activated) product.
            return z
        else:
            raise NotImplementedError(f"unsupported act {self.act!r}")
    else:
        ind_logits = torch.as_tensor(self.ind_logits, device=device)
        col_logits = torch.as_tensor(self.col_logits, device=device)
        z = ind_logits @ col_logits.T
        assert not torch.isnan(z).any(), "low rank score must be valid"
        if self.act == 'exp':
            return z.exp()
        elif self.act == 'softplus':
            return torch.nn.functional.softplus(z)
        elif self.act == 'sigmoid':
            return z.sigmoid()
        elif self.act == '_nnmf':
            # Raw (un-activated) product.
            return z
        else:
            raise NotImplementedError(f"unsupported act {self.act!r}")