def _unpack_param()

in diffq/ts_export.py


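The decoder consumes a `_DiffQPacked` tuple produced at export time. The layout below is a hypothetical sketch, inferred purely from how the tuple is destructured on the function's first line; the exact field types in diffq may differ.

from typing import List, Optional, Tuple
import torch

# Hypothetical layout, inferred from the destructuring in _unpack_param;
# not the authoritative diffq definition.
_DiffQPacked = Tuple[
    List[Optional[torch.Tensor]],  # packed_all_levels: one stream per bit width
    torch.Tensor,                  # scales: per-group scale info (structure assumed)
    torch.Tensor,                  # packed_bits: packed per-group bit widths
    List[int],                     # shape: original parameter shape
]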
import torch

# `unpack` (bit-stream decoding) and `uniform_unquantize` (uniform
# dequantization) are diffq helpers defined elsewhere in the package, as is
# the `_DiffQPacked` tuple type; their imports are omitted from this excerpt.


def _unpack_param(packed: _DiffQPacked, group_size: int, min_bits: int) -> torch.Tensor:
    """Called from TorchScript on the first forward pass to decode the packed
    weights back to FP32.
    """
    packed_all_levels, scales, packed_bits, shape = packed
    # Total number of elements in the original parameter tensor.
    numel = 1
    for dim in shape:
        numel *= dim
    # Per-group bit widths; they were stored offset by `min_bits`, so the
    # offset is added back after unpacking.
    bits = unpack(packed_bits, numel // group_size) + min_bits
    # One row of quantized levels per group of `group_size` weights.
    levels = torch.empty(bits.numel(), group_size, dtype=torch.short)
    # Levels are packed in a separate stream per bit width (entry `idx`
    # holds the groups quantized with `idx + 1` bits, or None if unused);
    # decode each stream into the rows of `levels` that use that width.
    for idx, packed_levels in enumerate(packed_all_levels):
        bit = idx + 1
        if packed_levels is not None:
            mask = bits == bit
            sub_levels = levels[mask]
            levels[mask] = unpack(packed_levels, sub_levels.numel()).view_as(sub_levels)
    # Broadcast each group's bit width over its levels for dequantization.
    bits = bits[:, None]
    unquant = uniform_unquantize(levels, scales, bits)
    # Reshape back to the original shape, with one explicit branch per
    # supported rank (TorchScript-friendly; ranks above 4 are not supported).
    if len(shape) == 4:
        return unquant.view(shape[0], shape[1], shape[2], shape[3])
    elif len(shape) == 3:
        return unquant.view(shape[0], shape[1], shape[2])
    elif len(shape) == 2:
        return unquant.view(shape[0], shape[1])
    elif len(shape) == 1:
        return unquant.view(shape[0])
    else:
        raise RuntimeError("Invalid number of dims")
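For intuition, here is a minimal, self-contained sketch of the grouped uniform dequantization that `uniform_unquantize` plausibly performs, followed by a round-trip check. The name `demo_uniform_unquantize` and the per-group (min, max) scale convention are assumptions for illustration, not diffq's actual API.

import torch

def demo_uniform_unquantize(levels: torch.Tensor, mn: torch.Tensor,
                            mx: torch.Tensor, bits: torch.Tensor) -> torch.Tensor:
    # levels: (n_groups, group_size) integer codes in [0, 2**bits - 1];
    # mn/mx: (n_groups, 1) per-group ranges; bits: (n_groups, 1) bit widths.
    steps = 2.0 ** bits.float() - 1.0
    return mn + levels.float() / steps * (mx - mn)

# Round-trip check: 8 groups of 4 weights, all quantized at 3 bits.
x = torch.randn(8, 4)
mn = x.min(dim=1, keepdim=True).values
mx = x.max(dim=1, keepdim=True).values
bits = torch.full((8, 1), 3)
steps = 2.0 ** bits.float() - 1.0
levels = torch.round((x - mn) / (mx - mn).clamp_min(1e-8) * steps)
x_hat = demo_uniform_unquantize(levels, mn, mx, bits)
# Reconstruction error is at most half a quantization step per group.
assert (x - x_hat).abs().max() <= ((mx - mn) / steps / 2).max() + 1e-6

In the real decoder, the integer codes themselves are additionally bit-packed into byte streams (one per bit width), which is what the `unpack` calls above undo before this final linear mapping.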