in optimum/quanto/tensor/weights/marlin/int4/packed.py [0:0]
import numpy as np
import torch


def _get_perm():
    perm = []
    # 32 == number of threads in 1 warp
    for i in range(32):
        perm1 = []
        # column id in the 16x8 weight block
        # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
        col = i // 4
        # one 32-bit word (int32) holds 8 4-bit weights; a thread owns 4 weights per 16x8 fragment
        # and the 4-bit weights are packed into int32, so it needs 2 16x8 fragments == 1 16x16 block
        for block in [0, 1]:
            # row id in the 16x8 weight block
            # check https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
            for row in [
                2 * (i % 4),
                2 * (i % 4) + 1,
                2 * (i % 4 + 4),
                2 * (i % 4 + 4) + 1,
            ]:
                # the 8 weights used by 1 thread (16x16 block) become contiguous in memory via interleaving
                # e.g. T0 uses (0, 16, 128, 144, 8, 24, 136, 152)
                perm1.append(16 * row + col + 8 * block)
        # one 128-bit load (CUDA int4 vector type) holds 4 32-bit words; a thread loads 128 bits at once,
        # so it needs 4 16x16 blocks == 1 16x64 block
        for j in range(4):
            # the 32 weights loaded by 1 thread (16x64 block) become contiguous in memory via interleaving
            # e.g. T0 uses ((0 ~ 152) + 0 * 256, (0 ~ 152) + 1 * 256, ..., (0 ~ 152) + 3 * 256)
            perm.extend([p + 256 * j for p in perm1])
    perm = np.array(perm)
    # reorder the 8 weights packed into each 32-bit word for faster dequantization
    # check https://arxiv.org/pdf/2211.10017
    interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
    perm = perm.reshape((-1, 8))[:, interleave].ravel()
    perm = torch.from_numpy(perm)
    return perm
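
A minimal usage sketch (assumption: this is not the packing routine from this file; `pack_with_perm` and its shape/dtype handling are illustrative only). It shows how the 1024-entry permutation returned by `_get_perm()` would typically be applied to reorder one 16x64 tile of unsigned 4-bit weights before packing 8 of them into each 32-bit word:

_perm = _get_perm()
# the permutation covers exactly one 16x64 tile, i.e. 1024 4-bit weights
assert sorted(_perm.tolist()) == list(range(1024))


def pack_with_perm(w_int4: torch.Tensor) -> torch.Tensor:
    """Reorder and pack unsigned 4-bit values (stored as int32) of shape (k, n).

    Assumes k is a multiple of 16 and n is a multiple of 64 (illustrative sketch only).
    """
    tile = 16
    k, n = w_int4.shape
    # split into 16x16 tiles and lay each tile out contiguously
    w = w_int4.reshape(k // tile, tile, n // tile, tile)
    w = w.permute(0, 2, 1, 3).reshape(k // tile, n * tile)
    # apply the per-thread permutation over each group of 1024 values (one 16x64 slice)
    w = w.reshape(-1, _perm.numel())[:, _perm].reshape(w.shape)
    # pack 8 4-bit values into each 32-bit word, in the interleaved nibble order set above
    q = np.zeros((w.shape[0], w.shape[1] // 8), dtype=np.uint32)
    wn = w.cpu().numpy().astype(np.uint32)
    for i in range(8):
        q |= wn[:, i::8] << (4 * i)
    return torch.from_numpy(q.view(np.int32))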