in diffq/ts_export.py [0:0]
def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor:
    """Function called from TorchScript on the first forward to decode the
    packed weights to FP32.

    Args:
        packed: tuple ``(packed_all_levels, scales, packed_bits, shape)`` —
            per-bit-width packed quantization levels, the per-group scales,
            the packed per-group bit widths, and the original tensor shape.
        group_size: number of consecutive values sharing one scale / bit width.
        min_bits: offset added to the unpacked bit widths (they are stored
            relative to this minimum).

    Returns:
        The dequantized FP32 tensor, viewed back to ``shape``.

    Raises:
        RuntimeError: if ``shape`` has more than 4 or fewer than 1 dimensions.
    """
    packed_all_levels, scales, packed_bits, shape = packed
    # Manual product instead of math.prod — presumably kept simple so the
    # function stays TorchScript-compatible.
    numel = 1
    for dim in shape:
        numel *= dim
    # One bit-width entry per group of `group_size` values.
    bits = unpack(packed_bits, numel // group_size) + min_bits
    # NOTE(review): rows of `levels` whose bit width has no packed payload
    # below stay uninitialized — presumably those widths never occur; verify.
    levels = torch.empty(bits.numel(), group_size, dtype=torch.short)
    for idx, packed_levels in enumerate(packed_all_levels):
        bit = idx + 1
        if packed_levels is not None:
            # Decode the groups quantized at this bit width into their rows.
            sub_levels = levels[bits == bit]
            levels[bits == bit] = unpack(packed_levels, sub_levels.numel()).view_as(sub_levels)
    # Broadcast bit widths over the group dimension for dequantization.
    bits = bits[:, None]
    unquant = uniform_unquantize(levels, scales, bits)
    # Explicit per-rank `view` calls rather than `view(shape)` — presumably a
    # TorchScript typing constraint on `shape`; confirm before simplifying.
    if len(shape) == 4:
        return unquant.view(shape[0], shape[1], shape[2], shape[3])
    elif len(shape) == 3:
        return unquant.view(shape[0], shape[1], shape[2])
    elif len(shape) == 2:
        return unquant.view(shape[0], shape[1])
    elif len(shape) == 1:
        return unquant.view(shape[0])
    else:
        # Fixed typo in the error message ("numbr" -> "number").
        raise RuntimeError("Invalid number of dim")