in quant/binary/optimal.py [0:0]
def opt_v1(matrix: torch.Tensor, ternary: bool, skip: int = 1) -> torch.Tensor:  # type: ignore
    """
    Search for the optimal v1 of the least squares 2-bit / ternary algorithm.

    Candidate values are derived by ``compute_mask`` from the absolute values
    of every ``skip``-th entry along the last dimension of ``matrix``. Each
    candidate is scored with ``cost_function`` and, for every row, the
    candidate with the smallest cost is returned.

    Args:
        matrix: A 2D tensor
        ternary: whether to do ternary optimization
        skip: increment in potential solution space to speed up computation

    Returns:
        Optimal v1, gathered along dim 1 using one argmin index per row.
    """
    # The whole search runs without autograd tracking.
    with torch.no_grad():
        magnitudes = matrix[..., ::skip].abs()
        candidate_mask, flat_candidates = compute_mask(magnitudes, ternary)

        # flat_candidates is a single flat vector; the per-row counts taken
        # from the mask say how to cut it back into one chunk per row (dim 0).
        per_row_counts = candidate_mask.sum(dim=1)

        if ternary:
            # Rarely hit ternary special case; the helper may rewrite both
            # the candidate vector and the per-row counts.
            flat_candidates, per_row_counts = _handle_ternary_min_gt_half_avg(
                magnitudes, flat_candidates, per_row_counts
            )

        row_chunks = torch.split(flat_candidates, per_row_counts.tolist())  # type: ignore
        # pad_sequence stacks the variable-length chunks into one rectangular
        # tensor, filling shorter rows with its default padding value.
        # NOTE(review): padded slots are scored too; presumably cost_function
        # keeps them from winning the argmin -- confirm.
        candidates = rnn_utils.pad_sequence(row_chunks, batch_first=True)  # type: ignore

        candidate_costs = cost_function(magnitudes, candidates, ternary)
        best_index = candidate_costs.argmin(dim=-1, keepdim=True)
        return candidates.gather(1, best_index)