def get_column_permutation()

in optimum/quanto/tensor/weights/marlin/fp8/packed.py [0:0]


def get_column_permutation(n_col: int) -> torch.Tensor:
    """
    Build the column index map from marlin-repacked weights back to natural order.

    Columns are handled in tiles of 256 values. Inside a tile, position `i`
    is mapped to `32 * (i % 8) + i // 8`, offset by the start of its tile.
    Each tile is then cut into 4 chunks of 64 values, and the chunks are
    regrouped so that chunk 0 of every tile comes first, then chunk 1 of
    every tile, and so on.

    Args:
        n_col: number of columns. Only `n_col // 256` complete tiles are
            covered; a trailing partial tile is ignored.

    Returns:
        A 1D integer tensor of length `256 * (n_col // 256)` holding the
        permuted column indices.
    """
    tile_size = 256
    chunk_size = 64
    chunks_per_tile = tile_size // chunk_size
    n_tiles = n_col // tile_size

    # Permutation inside a single tile.
    position = torch.arange(tile_size)
    in_tile = (position % 8) * 32 + position // 8

    # Shift the in-tile permutation by the first column of each tile: (n_tiles, 256).
    tile_start = torch.arange(n_tiles).unsqueeze(1) * tile_size
    per_tile = tile_start + in_tile

    # Chunk layout along the columns, e.g. for 3 tiles:
    #
    #   | 0 1 2 3 | 0 1 2 3 | 0 1 2 3 |      (each number is a chunk of 64 values)
    #
    # Reading chunk-major instead of tile-major visits the flat chunks in the
    # order [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11], i.e. all the 0s, then all
    # the 1s, etc.
    chunked = per_tile.reshape(n_tiles, chunks_per_tile, chunk_size)
    return chunked.transpose(0, 1).reshape(-1)