in diffq/torch_pack.py [0:0]
def unpack(packed: torch.Tensor, length: tp.Optional[int] = None):
"""Opposite of `pack`. You might need to specify the original length."""
storage_size = _storage_size(packed.dtype)
nbits, groups = packed.shape
out = torch.zeros(storage_size, groups, dtype=torch.int16, device=packed.device)
for in_bit in range(storage_size):
for out_bit in range(nbits):
bit_value = (packed[out_bit, :] >> in_bit) & 1
out[in_bit, :] = out[in_bit, :] | (bit_value.to(out) << out_bit)
out = out.view(-1)
if length is not None:
out = out[:length]
return out