def compute_mask()

in quant/binary/optimal.py [0:0]


def compute_mask(matrix: torch.Tensor, ternary: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flag candidate optimal values in each row of a 2D tensor of absolute values.

    Every row is sorted in ascending order and only the inner positions
    (the smallest and the largest element of each row are dropped) are
    examined. For an inner position ``k`` two quantities are derived from
    prefix sums:

    * ``left mean``  -- mean of the sorted elements ``0..k`` (inclusive),
    * ``right mean`` -- mean of the sorted elements strictly after ``k``.

    Position ``k`` is flagged when a candidate threshold falls between the
    sorted value at ``k`` and the sorted value at ``k + 1`` (both bounds
    inclusive). The candidate thresholds are:

    * half of the right mean (always tested),
    * the average of the left and right means (tested only when
      ``ternary`` is false).

    Note that the mask indexes positions in the *sorted* rows, not in the
    original column order of ``matrix``.

    Args:
        matrix: A 2D tensor of absolute values.
        ternary: whether the mask is computed for the ternary algorithm;
            if true, only the half-right-mean threshold is tested.

    Returns:
        A 2-tuple ``(mask, selected)``: ``mask`` is a boolean tensor with
        ``matrix.shape[1] - 2`` columns (zero columns for rows shorter than
        three elements), and ``selected`` is a 1D tensor holding the sorted
        inner values at the flagged positions.
    """
    sorted_vals = torch.sort(matrix, dim=1)[0]
    prefix = torch.cumsum(sorted_vals, dim=1)

    # Number of elements in the left part (positions 0..k) for every k.
    n_cols = matrix.shape[1]
    left_sizes = torch.arange(1, n_cols + 1, device=matrix.device)
    # Number of elements strictly to the right of every k.
    right_sizes = torch.flip(left_sizes, [0]) - 1
    # The last entry would be 0; it is sliced away below, so any non-zero
    # placeholder works and avoids a division by zero.
    right_sizes[-1] = 1

    # Sorted values at the inner positions and their right-hand neighbours.
    inner = sorted_vals[:, 1:-1]
    right_neighbour = sorted_vals[:, 2:]

    # Mean of everything to the right of each inner position.
    right_means = ((prefix[:, -1:] - prefix) / right_sizes)[:, 1:-1]

    # Rather than testing for equality, look for positions where the
    # threshold is >= the value at the position and <= the next value.
    half_right = 0.5 * right_means
    mask = (inner <= half_right) & (half_right <= right_neighbour)

    if not ternary:
        # Mean of everything up to and including each inner position.
        left_means = (prefix / left_sizes)[:, 1:-1]
        midpoints = 0.5 * (left_means + right_means)
        mask = mask | ((inner <= midpoints) & (midpoints <= right_neighbour))

    selected = torch.masked_select(inner, mask)
    return mask, selected