in optimum/quanto/tensor/packed.py [0:0]
def pack_weights(intweights: torch.Tensor, bits: int) -> torch.Tensor:
    """Pack low-bit integer weights into a compact ``torch.uint8`` tensor.

    Torch has no native 2-bit / 4-bit dtypes, so each value arrives stored in
    its own uint8. Packing stuffs ``8 // bits`` consecutive values (along the
    first axis) into a single uint8, value ``i`` of a group occupying bit
    positions ``[bits * i, bits * (i + 1))``. For example the four 2-bit
    values 3, 2, 1, 0 become the single byte ``0b11100100``.

    Args:
        intweights (`torch.Tensor`):
            The un-packed `torch.uint8` tensor.
        bits (`int`):
            The actual `bits` - can be 2, 4.
    """
    src_shape = intweights.shape
    group = 8 // bits  # how many values fit in one uint8
    # Ceil division: number of packed rows needed along axis 0.
    packed_rows = -(-src_shape[0] // group)
    out_shape = (packed_rows,) + tuple(src_shape[1:])
    out = torch.zeros(out_shape, device=intweights.device, dtype=torch.uint8)
    src = intweights.to(torch.uint8)

    def shift_left(values: torch.Tensor, amount: int) -> torch.Tensor:
        # The MPS backend does not support the << operator; an equivalent
        # multiplication by a power of two is used there instead.
        if values.device.type == "mps":
            return values * (2**amount)
        return values << amount

    # Number of bit-slots actually occupied; the last group may be ragged,
    # in which case its missing tail simply contributes nothing (zeros).
    slots = min(group, src_shape[0] // packed_rows + 1)
    for slot in range(slots):
        lo = slot * packed_rows
        hi = min(lo + packed_rows, src_shape[0])
        out[: hi - lo] |= shift_left(src[lo:hi], bits * slot)
    return out