def dual_solve_u()

in src/rime/util/cvx_bisect.py [0:0]


def dual_solve_u(v, s, alpha, eps, verbose=False, n_iters=100, gtol=0):
    """
    Solve for the dual variable u by per-row bisection on ``grad_u``.

    Background objective:

        min_{u>=0} max_pi L(pi, u, v)
            = E_xy [ u(x)alpha(x) + v(y)beta(y) + Softplus(1/eps)(s-u-v) ],
                where u = min{u>=0 : E_y[pi(x,y)] <= alpha(x)}

    This routine searches for the exact u with E_y[pi(x,y)] == alpha(x), i.e.
    a sign change of ``grad_u``. The search bracket is NOT clipped at 0, so
    the returned u may be negative.

    Args:
        v: dual variable indexed by y; moved to ``s.device`` and reshaped to
            (1, -1).
        s: score tensor; indexed as ``s[:, 0]`` and reduced over dim 1, so
            presumably 2-D with rows = x and columns = y (TODO confirm).
        alpha: target rate(s); clipped into [0, 1]. Scalar or broadcastable
            against the per-row u.
        eps: smoothing parameter of the Softplus term; the bracket is shifted
            by ``logit(alpha) * eps``.
        verbose: currently unused.
        n_iters: maximum number of ``grad_u`` evaluations (initial guesses
            plus bisection steps). If < 1, no search is done and the midpoint
            of the initial bracket is returned.
        gtol: stop early once ``max |grad_u| < gtol``. With the default 0 the
            strict comparison can never hold, so all ``n_iters`` steps run.

    Returns:
        (u, n_evals): the last evaluated u and the number of ``grad_u``
        evaluations performed. On the degenerate path where every alpha is 0
        (or every alpha is 1), returns (+inf or -inf per row, 0).

    Environment:
        CVX_STABLE: if set to a non-zero integer, v is passed through to
        ``s_u_v``/``grad_u`` on every call. Otherwise s is pre-shifted once
        with ``s_u_v(s, None, v)`` and v is passed as None.
    """
    alpha = torch.as_tensor(alpha, device=s.device).clip(0, 1)
    eps = torch.as_tensor(eps, device=s.device)

    z = alpha.log() - (1 - alpha).log()  # logit(alpha); +-inf at alpha 0 / 1

    if alpha.amax() <= 0 or alpha.amin() >= 1:  # z = +-infinity everywhere
        u = -z * torch.ones_like(s[:, 0])
        return u, 0

    v_inp = torch.as_tensor(v, device=s.device).reshape((1, -1))
    if int(os.environ.get('CVX_STABLE', '0')):
        v = v_inp
    else:
        # Fold v into s once so the helpers need not re-apply it every step.
        s = s_u_v(s, None, v)
        v = None

    # Row-wise bracket; the +-1e-3 margin keeps the endpoints strictly outside
    # the extreme shifted scores.
    s_shift = s_u_v(s, None, v)
    u_min = s_shift.amin(1) - z * eps - 1e-3
    u_max = s_shift.amax(1) - z * eps + 1e-3

    # Evaluated before any bisection step: avoids large negative prior_score
    # when s >= 0 in most valid cases.
    u_guess = [
        torch.zeros_like(u_min) + (0 - v_inp).amin() - z * eps - 1e-3,
    ]

    # Bisection assumes grad_u is <= 0 at u_min and >= 0 at u_max.
    assert (grad_u(u_min, v, s, alpha, eps) <= 0).all()
    assert (grad_u(u_max, v, s, alpha, eps) >= 0).all()

    u = (u_min + u_max) / 2  # returned unchanged when n_iters < 1
    n_evals = 0
    for i in range(n_iters):
        u = u_guess[i] if i < len(u_guess) else (u_min + u_max) / 2
        g = grad_u(u, v, s, alpha, eps)
        n_evals = i + 1
        assert not u.isnan().any()
        if g.abs().max() < gtol:
            break
        # Clamp so a guess lying outside the bracket can only tighten it,
        # never widen it; bisection midpoints are unaffected.
        u_min = torch.where(g < 0, torch.maximum(u, u_min), u_min)
        u_max = torch.where(g > 0, torch.minimum(u, u_max), u_max)

    return u, n_evals