in diffq/torch_pack.py [0:0]
import math

import torch


def pack(indexes, nbits: int = 0, storage_dtype: torch.dtype = torch.int16):
    """Pack a tensor of non-negative integer indexes into a compact bitwise layout.

    You can think of `indexes` as a "tensor" of bits of shape [L, nbits].
    Instead of concatenating it naively as [L * nbits], we look at it transposed as
    [nbits, L]. For L = 16 * G, we get [nbits, G, 16], which is trivial to store
    efficiently in int16 integers.
    There will be overhead if L is far from a multiple of 16 (e.g. L = 1), but for
    large model layers this is acceptable. The storage type can be changed with
    `storage_dtype`.
    `nbits` should be the number of bits on which the indexes are coded; it is
    determined automatically when set to 0.
    """
    assert not indexes.dtype.is_floating_point
    if indexes.numel() > 0:
        # Indexes must be non-negative and fit in 15 bits (the signed int16 range).
        assert indexes.max().item() < 2 ** 15
        assert indexes.min().item() >= 0
    if nbits == 0:
        # Smallest number of bits that can represent the largest index.
        nbits = int(math.ceil(math.log2(1 + indexes.max())))
    else:
        assert indexes.max().item() < 2 ** nbits
    indexes = indexes.reshape(-1)
    storage_size = _storage_size(storage_dtype)
    # Lay the flat indexes out as a rectangle with `storage_size` rows,
    # one column per output storage word.
    rect = as_rectangle(indexes, storage_size)
    out = torch.zeros(nbits, rect.shape[1], dtype=storage_dtype, device=indexes.device)
    for in_bit in range(nbits):
        for out_bit in range(storage_size):
            # Take bit `in_bit` of every index in row `out_bit` and store it
            # as bit `out_bit` of the corresponding output word.
            d = ((rect[out_bit] >> in_bit) & 1).to(out.dtype) << out_bit
            out[in_bit, :] |= d
    return out
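

# A minimal usage sketch (an illustration, not part of the original file). It assumes
# that `_storage_size` and `as_rectangle`, defined elsewhere in this file, behave as
# the code above implies: int16 words are 16 bits wide and the flat indexes are
# padded into a [storage_size, ceil(L / storage_size)] rectangle.
indexes = torch.randint(0, 8, (1000,))  # values in [0, 8), i.e. 3-bit indexes
packed = pack(indexes, nbits=3)
# Under the assumptions above, 1000 indexes are padded to 63 groups of 16, giving one
# int16 word per (bit, group) pair, i.e. an output of shape [3, 63].
print(packed.shape, packed.dtype)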