in src/rime/util/cvx_bisect.py [0:0]
def dual_solve_u(v, s, alpha, eps, verbose=False, n_iters=100, gtol=0):
    """Find the dual variable ``u`` by bisection on the sign of ``grad_u``.

    Objective, as stated in the original notes::

        min_{u>=0} max_pi L(pi, u, v)
            = E_xy[ u(x) alpha(x) + v(y) beta(y) + Softplus(1/eps)(s - u - v) ],

    where ``u = min{u >= 0 : E_y[pi(x, y)] <= alpha(x)}``. The search looks
    for the exact ``u`` such that ``E_y[pi(x, y)] == alpha(x)``.

    Args:
        v: Dual variable for the other side; converted to a tensor on
            ``s.device`` and reshaped to ``(1, -1)``.
        s: Score tensor; indexed as ``s[:, 0]`` and reduced over dim 1, so it
            must be at least 2-D.
        alpha: Target value(s); converted to a tensor and clipped to [0, 1].
        eps: Temperature-like scale; converted to a tensor on ``s.device``.
        verbose: Accepted for interface compatibility; not used here.
        n_iters: Maximum number of gradient evaluations.
        gtol: Stop early once ``max |grad_u| < gtol``. With the default of 0
            this never triggers, so all ``n_iters`` steps are run.

    Returns:
        Tuple ``(u, n_evals)``: the last candidate ``u`` that was evaluated
        and the number of gradient evaluations performed. ``n_evals`` is 0
        for the degenerate-``alpha`` shortcut, and for ``n_iters <= 0`` (in
        which case ``u`` is the midpoint of the initial bracket).

    Raises:
        AssertionError: If the initial bracket does not straddle a sign
            change of ``grad_u``, or if a candidate ``u`` contains NaN.
    """
    alpha = torch.as_tensor(alpha, device=s.device).clip(0, 1)
    eps = torch.as_tensor(eps, device=s.device)
    # logit(alpha); becomes -inf / +inf when alpha hits 0 / 1.
    z = alpha.log() - (1 - alpha).log()

    if alpha.amax() <= 0 or alpha.amin() >= 1:  # z = +-infinity
        # Degenerate target: no finite root to search for, return -z per row.
        u = -z * torch.ones_like(s[:, 0])
        return u, 0

    v_inp = torch.as_tensor(v, device=s.device).reshape((1, -1))
    if int(os.environ.get('CVX_STABLE', '0')):
        # Keep v separate and let the helpers combine it with s themselves.
        v = v_inp
    else:
        # Fold v into s once up front, then run the helpers with v=None.
        s = s_u_v(s, None, v)
        v = None

    # NOTE(review): s_u_v is defined elsewhere; presumably it returns
    # s - u - v with None meaning "skip that term". It is evaluated once here
    # and reused for both bracket ends (assumes it is deterministic).
    shifted = s_u_v(s, None, v)
    u_min = shifted.amin(1) - z * eps - 1e-3
    u_max = shifted.amax(1) - z * eps + 1e-3

    # Candidates evaluated before plain bisection starts. Original author's
    # note: avoids a large negative prior_score when s >= 0, as in most valid
    # cases.
    u_guess = [
        torch.zeros_like(u_min) + (0 - v_inp).amin() - z * eps - 1e-3,
    ]
    # Disabled alternative guesses, kept from the original:
    # u_guess.extend(
    #     s_u_v(s, None, v).topk(
    #         (alpha * s.shape[1] + 1).clip(None, s.shape[1]).int()
    #     ).values[:, -3:].T
    # )

    # The bracket must straddle the root: gradient <= 0 at the lower end and
    # >= 0 at the upper end.
    assert (grad_u(u_min, v, s, alpha, eps) <= 0).all()
    assert (grad_u(u_max, v, s, alpha, eps) >= 0).all()

    # Defined up front so that n_iters <= 0 returns a usable value instead of
    # failing on unbound locals.
    u = (u_min + u_max) / 2
    n_evals = 0
    for step in range(n_iters):
        if step < len(u_guess):
            u = u_guess[step]
        else:
            u = (u_min + u_max) / 2
        g = grad_u(u, v, s, alpha, eps)
        assert not u.isnan().any()
        n_evals = step + 1
        if g.abs().max() < gtol:
            break
        # Element-wise bracket update: move whichever end matches the sign of
        # the gradient; entries with g == 0 keep both ends unchanged.
        u_min = torch.where(g < 0, u, u_min)
        u_max = torch.where(g > 0, u, u_max)
    return u, n_evals