def pack()

in diffq/torch_pack.py


import math

import torch


def pack(indexes, nbits: int = 0, storage_dtype: torch.dtype = torch.int16):
    """You can think of indexes as a "Tensor" of bits of shape [L, nbits].
    Instead of concatenating naively as [L * nbits] bits, we look at it transposed as
    [nbits, L]. For L = 16 * G, we get [nbits, G, 16] which is trivial to store
    efficiently on int16 integers.
    There will be overhead if L is far from a multiple of 16 (e.g. 1) but for large
    model layers this is acceptable. Storage type can be changed.

    `nbits` should be the number of bits on which the indexes are coded, and will
    actually be determined automatically if set to 0.
    """
    assert not indexes.dtype.is_floating_point
    if indexes.numel() > 0:
        assert indexes.max().item() < 2 ** 15
        assert indexes.min().item() >= 0
        if nbits == 0:
            nbits = int(math.ceil(math.log2(1 + indexes.max().item())))
        else:
            assert indexes.max().item() < 2 ** nbits

    indexes = indexes.reshape(-1)
    storage_size = _storage_size(storage_dtype)  # bits per storage word, e.g. 16 for torch.int16
    rect = as_rectangle(indexes, storage_size)   # padded view of shape [storage_size, G]
    out = torch.zeros(nbits, rect.shape[1], dtype=storage_dtype, device=indexes.device)
    for in_bit in range(nbits):
        for out_bit in range(storage_size):
            # Bit `in_bit` of every index in row `out_bit` lands at bit position
            # `out_bit` of the matching word in bit-plane `in_bit` of the output.
            d = ((rect[out_bit] >> in_bit) & 1).to(out.dtype) << out_bit
            out[in_bit, :] |= d
    return out
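
To make the layout described in the docstring concrete, here is a minimal, self-contained sketch
of the same bit-transposition idea. `_rectangle_demo`, `_pack_demo` and `_unpack_demo` are written
for this example only: they mirror how `pack` uses `as_rectangle` and `_storage_size`, but the exact
padding and row layout of the real helpers is an assumption, so treat this as an illustration of the
technique rather than the library's implementation.

import math

import torch


def _rectangle_demo(flat, side):
    # Zero-pad to a multiple of `side`, then view as [side, G] so that column g
    # holds the `side` consecutive values flat[g * side : (g + 1) * side].
    groups = int(math.ceil(flat.numel() / side))
    padded = torch.zeros(side * groups, dtype=flat.dtype)
    padded[:flat.numel()] = flat
    return padded.view(groups, side).t()


def _pack_demo(indexes, nbits, side=16):
    rect = _rectangle_demo(indexes.reshape(-1), side)
    out = torch.zeros(nbits, rect.shape[1], dtype=torch.int16)
    for in_bit in range(nbits):
        for out_bit in range(side):
            # Shifting into bit 15 wraps into the int16 sign bit, which is fine
            # for purely bitwise storage.
            out[in_bit] |= ((rect[out_bit] >> in_bit) & 1).to(torch.int16) << out_bit
    return out


def _unpack_demo(packed, length, side=16):
    nbits, groups = packed.shape
    rect = torch.zeros(side, groups, dtype=torch.long)
    for in_bit in range(nbits):
        for out_bit in range(side):
            rect[out_bit] |= ((packed[in_bit].long() >> out_bit) & 1) << in_bit
    return rect.t().reshape(-1)[:length]


indexes = torch.randint(0, 2 ** 5, (1000,))   # 1000 five-bit codes
packed = _pack_demo(indexes, nbits=5)         # shape [5, ceil(1000 / 16)] = [5, 63]
assert torch.equal(_unpack_demo(packed, indexes.numel()), indexes)

Stored this way, the 1000 five-bit codes occupy 5 * 63 = 315 int16 words (630 bytes), close to the
625-byte ideal of 5000 bits, versus 2000 bytes if every index kept its own int16.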