def _bit_unpack_param()

in diffq/diffq.py [0:0]


    def _bit_unpack_param(self, qparam, packed, unpack_fn):
        """Unpack bitpacked representation. Should be overriden.
        """
        packed_all_levels, scales, packed_bits = packed
        bits = unpack_fn(packed_bits, qparam.logit.numel()) + self.min_bits
        bits = bits.to(qparam.param.device)
        levels = torch.empty(qparam.logit.numel(), self.group_size,
                             dtype=torch.short, device=qparam.param.device)
        for idx, packed_levels in enumerate(packed_all_levels):
            bit = idx + 1
            if packed_levels is None:
                continue
            sub_levels = levels[bits == bit]
            levels[bits == bit] = unpack_fn(
                packed_levels, sub_levels.numel()).view_as(sub_levels).to(sub_levels)
        return (levels, scales, bits)