in optimum/quanto/tensor/weights/marlin/fp8/packed.py [0:0]
def get_column_permutation(n_col: int) -> torch.Tensor:
    """
    Build the column index map going from marlin-repacked weights to natural order.

    Columns are handled by tiles of 256 values. Inside a tile, position `8 * hi + lo`
    (with `lo` in [0, 8) and `hi` in [0, 32)) receives the index `32 * lo + hi`, shifted
    by the first column of its tile. Each tile is then cut into 4 chunks of 64 indices,
    and the chunks are regrouped so that all the first chunks come first, then all the
    second chunks, and so on.

    Only `n_col // 256` whole tiles are covered: the returned 1D index tensor holds
    `256 * (n_col // 256)` entries (it is empty when `n_col < 256`).
    """
    tile_size = 256
    chunk_size = 64
    chunks_per_tile = tile_size // chunk_size  # 4 chunks of 64 values in each tile
    n_blocks = n_col // tile_size

    # Index pattern of one tile, before adding the tile offset.
    position = torch.arange(tile_size)
    tile_pattern = 32 * (position % 8) + position // 8

    # Add the first column of every tile, then cut each tile into its 64-wide chunks:
    # row `4 * tile + k` of `chunks` is the k-th chunk of that tile.
    tile_starts = torch.arange(n_blocks)[:, None] * tile_size
    chunks = (tile_starts + tile_pattern).reshape(chunks_per_tile * n_blocks, chunk_size)

    # With three tiles, the chunks are laid out as
    #
    #   | 0  1  2  3 | 0  1  2  3 | 0  1  2  3 |
    #
    # and must be read in the order [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11], i.e. the
    # row numbers 0..4*n_blocks-1 arranged as (tile, chunk) and read chunk-first.
    chunk_order = torch.arange(chunks_per_tile * n_blocks).reshape(n_blocks, chunks_per_tile).T.reshape(-1)

    return chunks[chunk_order].reshape(-1)