def cost_function()

in quant/binary/optimal.py [0:0]


def cost_function(matrix: torch.Tensor, v1s: torch.Tensor, ternary: bool = False) -> torch.Tensor:
    """
    Evaluate, for every candidate first scale v1, the residual cost of a 2-term quantization.

    This is equation (8) of the paper specialised to k=2. The expression follows from
    expanding s1 and s2 with the foldable quantization equation (9).

    Args:
        matrix: original 2D tensor
        v1s: 2D tensor of candidate v1 values, shaped (rows, candidates); its row
            dimension is broadcast against the rows of ``matrix``
        ternary: if True, reuse v1 as the second scale (ternary cost) instead of
            fitting the second scale from the residual

    Returns:
        2D tensor of L2 norms, indexed by (row, candidate), of what remains of each
        row after subtracting both binary components
    """
    rows = matrix.view(matrix.shape[0], 1, -1)  # (rows, 1, cols) for a 2D matrix
    candidates = v1s.view(v1s.shape[0], v1s.shape[1], 1)  # (rows, candidates, 1)

    # Residual left after removing the first binary component v1 * sign(x).
    residual = rows - candidates * binary_sign(rows)

    # Second scale: tied to v1 in the ternary case, otherwise the mean absolute
    # residual along the last axis (kept as a broadcastable trailing dim).
    second_scale = candidates if ternary else residual.abs().mean(dim=-1, keepdim=True)

    # Whatever the second binary component fails to explain, reduced to one norm per candidate.
    leftover = residual - second_scale * binary_sign(residual)
    return torch.norm(leftover, dim=-1)  # type: ignore