in quant/binary/optimal.py [0:0]
def cost_function(matrix: torch.Tensor, v1s: torch.Tensor, ternary: bool = False) -> torch.Tensor:
    """
    Compute the cost function to find the optimal v1.

    The cost function is equation (8) in the paper, for k=2.
    It can be derived by expanding s1, s2 using the foldable quantization equation (9).

    Each row of ``matrix`` is evaluated against each of its candidate v1 values. The
    first level subtracts ``v1 * binary_sign(row)``, leaving the residual
    ``s2 = row - v1 * binary_sign(row)``. The second level subtracts
    ``v2 * binary_sign(s2)``. The cost is the L2 norm of what remains.

    Args:
        matrix: original 2D tensor of shape (rows, cols). Each row is handled
            independently, and any trailing dimensions are flattened into one.
        v1s: 2D tensor of shape (rows, k) holding k candidate v1 values per row.
            Its first dimension must broadcast against ``matrix``'s first dimension.
        ternary: compute cost for the ternary function. In that case v2 is tied
            to v1. Otherwise v2 is the mean absolute value of the residual row.

    Returns:
        Norms as a 2D tensor of shape (rows, k): the residual error norm for
        every candidate v1 of every row.
    """
    # Use reshape rather than view so non-contiguous inputs (e.g. a transposed
    # matrix) are accepted. For contiguous inputs this is still a zero-copy view.
    matrix_view = matrix.reshape(matrix.shape[0], 1, -1)  # (rows, 1, cols)
    v1s_view = v1s.reshape(v1s.shape[0], v1s.shape[1], 1)  # (rows, k, 1)
    # Residual after the first binary level; broadcasts to (rows, k, cols).
    s2_arg = matrix_view - v1s_view * binary_sign(matrix_view)
    if ternary:
        # Ternary case: the second level reuses the first level's scale.
        v2 = v1s_view
    else:
        # Second-level scale: mean |s2| along each row -> (rows, k, 1).
        v2 = s2_arg.abs().mean(dim=-1, keepdim=True)
    # Norm of the remaining error over the column dimension -> (rows, k).
    return torch.norm(s2_arg - v2 * binary_sign(s2_arg), dim=-1)  # type: ignore