in quant/binary/optimal.py [0:0]
def compute_mask(matrix: torch.Tensor, ternary: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute a mask of candidate optimal values for a 2D tensor of absolute values.

    Every row is sorted in ascending order and each interior position (all
    positions except the first and the last of the sorted row) is tested:
    a position is selected when a threshold derived from the running means
    falls between the value at that position and the next larger value.

    Args:
        matrix: A 2D tensor of absolute values, shape ``(rows, cols)``.
        ternary: whether we are computing mask for ternary algorithm. When
            True only the threshold based on the right-hand mean is tested;
            otherwise the threshold based on both means is tested as well.

    Returns:
        A 2-tuple ``(mask, masked_vs)``. ``mask`` is a boolean tensor of
        shape ``(rows, cols - 2)`` aligned with the interior of the sorted
        rows, and ``masked_vs`` is a 1D tensor holding the sorted values
        selected by ``mask``. When ``matrix`` has fewer than three columns
        there are no interior positions, so an empty ``(rows, 0)`` mask and
        an empty values tensor are returned.
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    if n_cols < 3:
        # No interior positions exist. Returning early also avoids indexing
        # into an empty ``counts_rev`` when the matrix has zero columns.
        empty_mask = torch.zeros((n_rows, 0), dtype=torch.bool, device=matrix.device)
        return empty_mask, matrix.new_empty(0)

    values, _ = torch.sort(matrix, dim=1)
    cum_sums = values.cumsum(dim=1)
    # counts[i] is the number of elements up to and including position i
    counts = torch.arange(1, n_cols + 1, device=matrix.device)
    # counts_rev[i] is the number of elements strictly to the right of position i
    counts_rev = torch.flip(counts, [0]) - 1
    counts_rev[-1] = 1  # avoid division by 0, value at this pos. will not be used

    inner = values[:, 1:-1]
    right_neighbours = values[:, 2:]

    m1s = None
    if not ternary:
        # m1s stores cumulative means from left to right (chopping left and right most values)
        m1s = (cum_sums / counts)[:, 1:-1]
    # m2s stores cumulative means from right to left (chopping left and right most values)
    m2s = ((cum_sums[:, -1:] - cum_sums) / counts_rev)[:, 1:-1]

    # re-using m1s and m2s to save memory:
    # m1s becomes the midpoint of the left and right means, m2s becomes half
    # of the right mean; both are candidate thresholds.
    if not ternary:
        m1s = 0.5 * (m1s + m2s)
    m2s = 0.5 * m2s

    # Find potential solutions in inner region and boundary.
    # Instead of testing for equality, select positions where the threshold
    # is >= the value at the position and <= the next value to the right.
    mask = (inner <= m2s) & (m2s <= right_neighbours)
    if not ternary:
        mask = mask | ((inner <= m1s) & (m1s <= right_neighbours))

    masked_vs = torch.masked_select(inner, mask)
    return mask, masked_vs