def pack_weights()

in optimum/quanto/tensor/packed.py [0:0]


def pack_weights(intweights: torch.Tensor, bits: int) -> torch.Tensor:
    """
    Pack int4 / int2 weights into a uint8 tensor.

    What does packing mean? Assume we have 4 values that fit in 2 bits but are encoded in 8 bits
    (because torch does not have native support for 2-bit datatypes)

    > 0000 0011 | 0000 0010 | 0000 0001 | 0000 0000

    We can pack them in a single 8-bit uint value

    > 1110 0100

    Therefore instead of saving 4 values in 8-bit precision we save a single value of 8-bit precision saving 24 bits in total.

    Args:
        intweights (`torch.Tensor`):
            The un-packed `torch.uint8` tensor. Values must fit in `bits` bits;
            larger values would silently corrupt neighboring packed entries.
        bits (`int`):
            The actual `bits` - can be 2, 4

    Returns:
        `torch.Tensor`: a `torch.uint8` tensor on the same device, whose first
        dimension is ``ceil(shape[0] / (8 // bits))`` and whose remaining
        dimensions match the input.

    Raises:
        ValueError: if `bits` is not 2 or 4.
    """
    if bits not in (2, 4):
        raise ValueError(f"bits must be 2 or 4, got {bits}")

    original_shape = intweights.shape
    values_per_item = 8 // bits
    # Ceil-divide along dim 0 so a partial final group still gets a packed row.
    row_dim = (original_shape[0] + values_per_item - 1) // values_per_item

    if len(original_shape) == 1:
        packed_tensor_shape = (row_dim,)
    else:
        packed_tensor_shape = (row_dim, *original_shape[1:])

    packed = torch.zeros(packed_tensor_shape, device=intweights.device, dtype=torch.uint8)
    # Empty input: nothing to pack. Return early to avoid a ZeroDivisionError
    # from `original_shape[0] // row_dim` when row_dim == 0.
    if row_dim == 0:
        return packed
    unpacked = intweights.to(torch.uint8)

    def lshift(t: torch.Tensor, shift: int):
        if t.device.type == "mps":
            # lshift is not supported on MPS device; emulate with a multiply
            return t * (2**shift)
        return t << shift

    # Number of groups actually present (the last group may be partial).
    it = min(values_per_item, (original_shape[0] // row_dim) + 1)
    for i in range(it):
        start = i * row_dim
        end = min(start + row_dim, original_shape[0])
        # Group i lands in bit positions [bits*i, bits*(i+1)) of each packed byte.
        packed[: (end - start)] |= lshift(unpacked[start:end], bits * i)

    return packed