in src/rime/util/__init__.py [0:0]
def _argsort(S, tie_breaker=1e-10, device="cpu"):
    """Return indices that sort every entry of ``S`` in descending order.

    ``S`` is flattened and sorted as one 1-D sequence. The flat positions
    are then mapped back to per-axis coordinates with ``np.unravel_index``.

    Parameters
    ----------
    S :
        Score container exposing a ``shape`` attribute (e.g. a numpy array or
        a torch tensor). If it has an ``eval`` method it is first evaluated
        with ``S.eval(device)``. If it has a ``batch_size`` smaller than
        ``S.shape[0]``, a warning is issued and the numpy path is forced.
        NOTE(review): the project-specific lazy score types that provide
        ``eval``/``batch_size`` are defined elsewhere — confirm with callers.
    tie_breaker : float
        When positive, uniform noise in ``[0, tie_breaker)`` is added to the
        scores before sorting, so exactly tied scores are ordered randomly.
    device :
        ``None`` sorts on the host with ``numpy.argsort``. Any other value is
        passed to ``torch.as_tensor``, and ``torch.argsort`` runs on that
        device; the result is copied back to the CPU as a numpy array.

    Returns
    -------
    tuple of numpy.ndarray
        One index array per dimension of the original ``S.shape``, as
        returned by ``np.unravel_index``. The first entries address the
        highest scores.
    """
    # Count from the shape instead of ``S.size``: on a torch tensor ``size``
    # is a method and cannot be formatted with ``:,``.
    n_scores = int(np.prod(S.shape))
    print(f"_argsort {n_scores:,} scores on device {device}; ", end="")
    if hasattr(S, "batch_size") and S.batch_size < S.shape[0]:
        warnings.warn(f"switching numpy.argsort due to {S.batch_size}<{S.shape[0]}")
        device = None
    if hasattr(S, "eval"):
        S = S.eval(device)
    shape = S.shape

    if device is None:
        if tie_breaker > 0:
            S = S + np.random.rand(*S.shape) * tie_breaker
        dtype = getattr(S, "dtype", None)
        if isinstance(dtype, np.dtype) and dtype.kind in "bu":
            # Bool scores cannot be negated and unsigned ones wrap around;
            # widen to the smallest type that holds both signs.
            S = S.astype(np.promote_types(dtype, np.int8))
        # Negate so that an ascending argsort yields descending scores.
        S = -S.reshape(-1)
        _empty_cache()
        argsort_ind = np.argsort(S)
    else:
        S = torch.as_tensor(S, device=device)
        if tie_breaker > 0:
            # Pass the shape as one tuple so 0-d inputs work as well.
            S = S + torch.rand(S.shape, device=device) * tie_breaker
        if not S.is_signed():
            # Bool/unsigned tensors cannot be negated meaningfully.
            S = S.long()
        # Negate so that an ascending argsort yields descending scores.
        S = -S.reshape(-1)
        _empty_cache()
        argsort_ind = torch.argsort(S).cpu().numpy()
    return np.unravel_index(argsort_ind, shape)