def opt_v1()

in quant/binary/optimal.py [0:0]


def opt_v1(matrix: torch.Tensor, ternary: bool, skip: int = 1) -> torch.Tensor:  # type: ignore
    """
    Search for the per-row optimal v1 of the least squares 2-bit / ternary algorithm.

    Candidate values are produced by ``compute_mask`` from the absolute values
    of every ``skip``-th entry along the last dimension of ``matrix``. Each
    candidate is scored with ``cost_function`` and the cheapest one per row is
    returned. The whole computation runs with gradient tracking disabled.

    Args:
        matrix: A 2D tensor
        ternary: whether to do ternary optimization
        skip: stride used to subsample the last dimension, shrinking the
            candidate solution space to speed up computation

    Returns:
        Optimal v1, one value per row (the selected column is kept as a
        dimension of size 1).
    """
    with torch.no_grad():
        # Work on magnitudes of the subsampled columns only; the same
        # subsampled tensor is later used to evaluate the cost.
        abs_sub = torch.abs(matrix[..., ::skip])
        candidate_mask, flat_candidates = compute_mask(abs_sub, ternary)

        # flat_candidates is a flat vector; the number of selected entries
        # per row (dim 0) tells us where one row's candidates end and the
        # next row's begin.
        per_row_counts = candidate_mask.sum(dim=1)

        if ternary:
            # rarely-hit ternary special case, delegated to a helper that may
            # adjust both the candidates and the per-row counts
            flat_candidates, per_row_counts = _handle_ternary_min_gt_half_avg(
                abs_sub, flat_candidates, per_row_counts
            )

        # Regroup the flat candidates by row, then right-pad shorter rows
        # (pad_sequence pads with zeros by default) to get a dense 2D tensor.
        # NOTE(review): the zero padding is scored by cost_function as well;
        # presumably a padded entry never wins the argmin -- confirm.
        candidates = rnn_utils.pad_sequence(  # type: ignore
            torch.split(flat_candidates, per_row_counts.tolist()),  # type: ignore
            batch_first=True,
        )

        candidate_costs = cost_function(abs_sub, candidates, ternary)
        best = candidate_costs.argmin(dim=-1, keepdim=True)

        # Pick, for every row, the candidate with the lowest cost.
        return candidates.gather(1, best)