in src/UtilsAvx512.cc [309:590]
void transpose_avx512(
int64_t M,
int64_t N,
const float* src,
unsigned ld_src,
float* dst,
unsigned ld_dst) {
unsigned ib = 0, jb = 0;
if (N % 16 > 0 && N % 16 < 4) {
// If the remainder has n < 4 columns (n = N % 16), we use the SSE kernel for
// the remainder because it requires 4 * (2 * 4 + 2 * n) = 32 + 8n
// instructions instead of the 4 * 16 + 2 * n = 64 + 2n instructions needed in
// the masked AVX512 kernel.
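// For example, with n = 3 remainder columns this is 32 + 8 * 3 = 56
// instructions on the SSE path versus 64 + 2 * 3 = 70 in the masked
// AVX512 kernel.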
for (ib = 0; ib + 16 <= M; ib += 16) {
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_16x16_avx512(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
for (unsigned i = ib; i < ib + 16; i += 4) {
transpose_kernel_mxn_sse<4>(
N - jb,
&src[i * ld_src + jb],
ld_src,
&dst[i + jb * ld_dst],
ld_dst);
}
}
} else if (N % 16 == 4) {
// If the remainder has 4 columns, we use the SSE kernel for the remainder
// because it requires 4 * 16 = 64 instructions instead of 4 * 16 + 2 * 4 =
// 72 instructions needed in the masked AVX512 kernel.
for (ib = 0; ib + 16 <= M; ib += 16) {
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_16x16_avx512(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
for (unsigned i = ib; i < ib + 16; i += 4) {
transpose_kernel_4x4_sse(
&src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
}
}
} else if (N % 16 == 8) {
// If the remainder has 8 columns, we use the AVX2 kernel for the remainder
// because it requires 2 * 40 = 80 instructions (two 8x8 AVX2 kernel calls)
// instead of the 4 * 16 + 2 * 8 = 80 instructions plus looping overhead
// needed in the masked AVX512 kernel.
for (ib = 0; ib + 16 <= M; ib += 16) {
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_16x16_avx512(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
for (unsigned i = ib; i < ib + 16; i += 8) {
transpose_kernel_8x8_avx2(
&src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
}
}
} else {
for (ib = 0; ib + 16 <= M; ib += 16) {
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_16x16_avx512(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<16>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
}
}
// Specialization for small M - ib cases so that the compiler can inline
// transpose_kernel_mxn_avx512 and unroll the loops whose iteration count
// depends on M - ib.
// Specializing on m helps more than on n in transpose_kernel_mxn_avx512
// because that function has more loops whose iteration count depends on m.
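// For example, with M = 21 the 16x16 main loop above covers rows 0..15, so
// M - ib = 5 and the switch below dispatches to transpose_kernel_mxn_avx2<5>.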
switch (M - ib) {
case 1:
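// A single remaining row: a plain scalar copy of the row into a column.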
for (unsigned j = 0; j < N; ++j) {
dst[ib + j * ld_dst] = src[ib * ld_src + j];
}
break;
case 2:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_sse<2>(
4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_sse<2>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 3:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_mxn_sse<3>(
4, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_sse<3>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 4:
for (jb = 0; jb + 4 <= N; jb += 4) {
transpose_kernel_4x4_sse(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_sse<4>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 5:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_avx2<5>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx2<5>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 6:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_mxn_avx2<6>(
8, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx2<6>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 7:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<7>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<7>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 8:
for (jb = 0; jb + 8 <= N; jb += 8) {
transpose_kernel_8x8_avx2(
&src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx2<8>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 9:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<9>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<9>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 10:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<10>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<10>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 11:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<11>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<11>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 12:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<12>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<12>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 13:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<13>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<13>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 14:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<14>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<14>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
case 15:
for (jb = 0; jb + 16 <= N; jb += 16) {
transpose_kernel_mxn_avx512<15>(
16, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
}
if (jb < N) {
transpose_kernel_mxn_avx512<15>(
N - jb,
&src[ib * ld_src + jb],
ld_src,
&dst[ib + jb * ld_dst],
ld_dst);
}
break;
}
}
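
// A minimal usage sketch (not part of the original source): it exercises
// transpose_avx512 on a small row-major matrix and checks the result against
// the scalar definition dst[j][i] = src[i][j]. The function name
// transpose_avx512_smoke_test and the chosen sizes are hypothetical; running
// it assumes the kernels above are linked in and the host CPU supports
// AVX512. The includes would normally sit at the top of the file.
#include <cassert>
#include <cstdint>
#include <vector>

void transpose_avx512_smoke_test() {
  constexpr int64_t kM = 21; // odd row count exercises the M - ib switch
  constexpr int64_t kN = 13; // odd column count exercises the masked remainder
  std::vector<float> src(kM * kN);
  std::vector<float> dst(kN * kM);
  for (int64_t i = 0; i < kM * kN; ++i) {
    src[i] = static_cast<float>(i);
  }
  // For densely packed row-major matrices, ld_src = N and ld_dst = M.
  transpose_avx512(kM, kN, src.data(), kN, dst.data(), kM);
  for (int64_t i = 0; i < kM; ++i) {
    for (int64_t j = 0; j < kN; ++j) {
      assert(dst[j * kM + i] == src[i * kN + j]);
    }
  }
}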