in src/UtilsAvx2.cc [174:296]
void transpose_avx2(
    unsigned M,
    unsigned N,
    const uint8_t* src,
    unsigned ld_src,
    uint8_t* dst,
    unsigned ld_dst) {
  unsigned ib = 0, jb = 0;
  // src(i, j) lands at dst(j, i): the (ib, jb) source tile starts at
  // src[ib * ld_src + jb] and is written starting at dst[jb * ld_dst + ib].
  if (M >= 8) {
    // Main blocking loop: 8 source rows at a time. Full 8x32 tiles go through
    // the fixed-size kernel; the N % 32 column remainder of each row block
    // goes through the variable-width kernel.
    for (ib = 0; ib + 8 <= M; ib += 8) {
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_8x32_avx2(
            &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N) {
        transpose_kernel_mxn_avx2_uint8<8>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      }
    }
  }
  // Handle the remaining M - ib rows (fewer than 8) with kernels specialized
  // on the row count.
  switch (M - ib) {
    case 1:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<1>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N)
        transpose_kernel_mxn_avx2_uint8<1>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      break;
    case 2:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<2>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N)
        transpose_kernel_mxn_avx2_uint8<2>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      break;
    case 3:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<3>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N)
        transpose_kernel_mxn_avx2_uint8<3>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      break;
    case 4:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<4>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N)
        transpose_kernel_mxn_avx2_uint8<4>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      break;
    case 5:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<5>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N)
        transpose_kernel_mxn_avx2_uint8<5>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      break;
    case 6:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<6>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N)
        transpose_kernel_mxn_avx2_uint8<6>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      break;
    case 7:
      for (jb = 0; jb + 32 <= N; jb += 32) {
        transpose_kernel_mxn_avx2_uint8<7>(
            32, &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
      }
      if (jb < N)
        transpose_kernel_mxn_avx2_uint8<7>(
            N - jb,
            &src[ib * ld_src + jb],
            ld_src,
            &dst[ib + jb * ld_dst],
            ld_dst);
      break;
  }
}
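
// Hypothetical usage sketch (not part of UtilsAvx2.cc): transposes a 13 x 70
// row-major uint8_t matrix, which exercises all three paths above: the 8x32
// tile kernel, the N % 32 column remainder, and the M % 8 row remainder
// handled by the switch. The example name, dimensions, and tight leading
// dimensions (ld_src = N, ld_dst = M) are illustrative assumptions, and
// <cstdint> and <vector> are assumed to be included.
void transpose_avx2_usage_example() {
  constexpr unsigned M = 13; // source rows
  constexpr unsigned N = 70; // source columns
  std::vector<uint8_t> src(M * N, 0); // M x N input, row-major
  std::vector<uint8_t> dst(N * M, 0); // receives the N x M transpose
  transpose_avx2(M, N, src.data(), /*ld_src=*/N, dst.data(), /*ld_dst=*/M);
}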