in diffq/diffq.py [0:0]
def _bit_unpack_param(self, qparam, packed, unpack_fn):
"""Unpack bitpacked representation. Should be overriden.
"""
packed_all_levels, scales, packed_bits = packed
bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits
bits = bits.to(qparam.param.device)
levels = torch.empty(qparam.logit.numel(), self.group_size,
dtype=torch.short, device=qparam.param.device)
for idx, packed_levels in enumerate(packed_all_levels):
bit = idx + 1
if packed_levels is None:
continue
sub_levels = levels[bits == bit]
levels[bits == bit] = unpack_fn(
packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels)
return (levels, scales, bits)