def _get_perm()

in optimum/quanto/tensor/weights/marlin/int4/packed.py


import numpy as np
import torch


def _get_perm():
    perm = []
    # 32 == number of threads in one warp
    for i in range(32):
        perm1 = []
        # column id in 16x8 weight block
        # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
        col = i // 4
        # one 32-bit word (int32) holds eight 4-bit weights; a thread owns 4 weights per 16x8 fragment,
        # so two 16x8 fragments (i.e. one 16x16 block) are needed to fill a single int32
        for block in [0, 1]:
            # row id in 16x8 weight block
            # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                # the 8 weights a thread handles within a 16x16 block are made contiguous in memory by this interleaving
                # e.g. T0 uses (0, 16, 128, 144, 8, 24, 136, 152)
                perm1.append(16 * row + col + 8 * block)
        # one 128-bit load (CUDA int4 vector type) == four int32s; a thread loads 128 bits at once,
        # so four 16x16 blocks (i.e. one 16x64 block) are needed
        for j in range(4):
            # the 32 weights a thread loads from a 16x64 block are made contiguous in memory by this interleaving
            # e.g. T0 uses ((0 ~ 152) + 0 * 256, (0 ~ 152) + 1 * 256, ..., (0 ~ 152) + 3 * 256)
            perm.extend([p + 256 * j for p in perm1])
    perm = np.array(perm)
    # reorder the 8 weights packed into each int32 for faster dequantization
    # check https://arxiv.org/pdf/2211.10017
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm
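
The permutation above is consumed when repacking 4-bit weights: flattened groups of perm.numel() == 1024 values (one 16x64 tile) are reindexed with it, and then 8 consecutive 4-bit values are ORed into each int32. The sketch below is a minimal, hypothetical illustration of that use, not the packing code from packed.py; the helper name _pack_tiles and the (num_tiles, 1024) input layout are assumptions made for the example.

def _pack_tiles(unpacked, perm):
    # unpacked: (num_tiles, 1024) int32 tensor of values in [0, 15],
    # one row per 16x64 tile (assumed layout, illustration only)
    permuted = unpacked[:, perm]
    packed = torch.zeros((permuted.shape[0], permuted.shape[1] // 8), dtype=torch.int32)
    # pack 8 consecutive 4-bit values into each 32-bit word
    for i in range(8):
        packed |= permuted[:, i::8] << (4 * i)
    return packed


perm = _get_perm()
tiles = torch.randint(0, 16, (4, 1024), dtype=torch.int32)
packed = _pack_tiles(tiles, perm)  # shape (4, 128)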