in tensorflow/tensorflow/lite/experimental/ruy/kernel_arm64.cc [2204:3546]
void Kernel8bitNeonDotprodOutOfOrder(const KernelParams8bit<8, 8>& params) {
gemmlowp::ScopedProfilingLabel label(
"Kernel (kNeonDotprod, optimized for out-of-order cores)");
CheckOffsetsInKernelParams8bit(params);
const std::int8_t* lhs_col_ptr = params.lhs_base_ptr;
const std::int8_t* rhs_col_ptr = params.rhs_base_ptr;
const std::int8_t* lhs_ptr = lhs_col_ptr;
const std::int8_t* rhs_ptr = rhs_col_ptr;
void* dst_col_ptr = params.dst_base_ptr;
void* dst_ptr = dst_col_ptr;
int row = params.start_row;
int col = params.start_col;
// The asm kernel below has the following NEON register allocation:
//
// v16 -- v31 are int32 accumulators.
// During accumulation, v0 -- v15 are used to load int8 data from LHS and
// RHS. At least v0 and v1 are used to load an 8x4 block of LHS, and v2 and
// v3 are used to load a 4x8 block of RHS, like this:
//
// int8 RHS 4x8 block
// /-----------------------------------------\
// |v2.b[0] ... v2.b[12] v3.b[0] ... v3.b[12]|
// | ... ... |
// |v2.b[3] ... v2.b[15] v3.b[3] ... v3.b[15]|
// \-----------------------------------------/
// int8 LHS 8x4 block
// /---------------------\ /-----------------------------------------\
// |v0.b[0] ... v0.b[3] | |v16.s[0] ... v30.s[0]|
// | ... ... | | ... ... |
// |v0.b[12] ... v0.b[15]| |v16.s[3] ... v30.s[3]|
// |v1.b[0] ... v1.b[3] | |v17.s[0] ... v31.s[0]|
// | ... ... | | ... ... |
// |v1.b[12] ... v1.b[15]| |v17.s[3] ... v31.s[3]|
// \---------------------/ \-----------------------------------------/
// int32 accumulators 8x8 block
//
// In the RUY_OPT_MAX_STREAMING part of the kernel, this elementary step
// is repeated 4 times, using 4x more registers for LHS and RHS: instead
// of using only v0 -- v3 for LHS and RHS, we use all of v0 -- v15.
//
// Outside of the RUY_OPT_MAX_STREAMING part of the kernel, v4 -- v7 are
// unused, and v8 -- v15 are used for loading parameters used for the
// post-accumulation part of the kernel.
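// As an illustrative sketch (not part of the kernel): each sdot
// instruction below, "sdot vA.4s, vL.16b, vR.4b[k]", computes in C terms
// roughly:
//
//   for (int i = 0; i < 4; i++) {
//     for (int d = 0; d < 4; d++) {
//       accum[i] += (int32_t)lhs[4 * i + d] * (int32_t)rhs[4 * k + d];
//     }
//   }
//
// i.e. it accumulates one 4x1 column of int32 values from a 4x4 int8
// LHS sub-block times one 4-deep int8 RHS column.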
asm volatile(
#define RUY_MAKE_ZERO(reg) "dup " #reg ".4s, wzr\n"
// clang-format off
// Load some parameters into registers.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
"ldr w8, [%[params], #" RUY_STR(RUY_OFFSET_LAST_COL) "]\n"
"ldr w9, [%[params], #" RUY_STR(RUY_OFFSET_LHS_STRIDE) "]\n"
"ldr w10, [%[params], #" RUY_STR(RUY_OFFSET_RHS_STRIDE) "]\n"
"ldr w11, [%[params], #" RUY_STR(RUY_OFFSET_DST_STRIDE) "]\n"
"ldr w12, [%[params], #" RUY_STR(RUY_OFFSET_DEPTH) "]\n"
// Load the first 32 bytes of LHS and RHS data.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Clear accumulators.
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
// Perform the first few multiply-adds on the data that we have already
// loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// Main loop of the whole GEMM, over rows and columns of the
// destination matrix.
"1:\n"
// Optional, maximally-streaming, partial-unrolling (4x unrolled)
// optimization of the kernel inner loop (over depth). For more
// comments, see the non-unrolled loop below after the #endif.
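// In C terms, this unrolled path is roughly (illustrative sketch, using
// the same w1/w3 register names as the asm below):
//
//   if (depth >= 32) {
//     int w3 = depth & ~15;  // depth rounded down to a multiple of 16
//     int w1 = 16;           // 16 levels of depth already loaded
//     do {
//       w1 += 16;
//       // 4 copies of the elementary 4-depth step, on register groups
//       // (v0..v3), (v4..v7), (v8..v11), (v12..v15), with the loads for
//       // the next iteration interleaved among the sdot instructions.
//     } while (w1 < w3);
//   }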
#if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
"cmp w12, #32\n"
"blt 78f\n"
"ld1 {v4.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v5.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v6.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v7.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v8.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v9.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v10.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v11.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v12.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v13.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v14.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v15.16b}, [%[rhs_ptr]], #16\n"
"mov w1, #16\n"
"and w3, w12, #-16\n"
"81:\n"
"add w1, w1, #16\n"
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ldr q0, [%[lhs_ptr], #0]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ldr q2, [%[rhs_ptr], #0]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ldr q1, [%[lhs_ptr], #16]\n"
".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
"ldr q3, [%[rhs_ptr], #16]\n"
".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
"ldr q5, [%[lhs_ptr], #48]\n"
".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
"ldr q7, [%[rhs_ptr], #48]\n"
".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
"ldr q4, [%[lhs_ptr], #32]\n"
".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
"ldr q6, [%[rhs_ptr], #32]\n"
".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
"ldr q9, [%[lhs_ptr], #80]\n"
".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
"ldr q11, [%[rhs_ptr], #80]\n"
".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
"ldr q8, [%[lhs_ptr], #64]\n"
".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
"ldr q10, [%[rhs_ptr], #64]\n"
".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
"add %[lhs_ptr], %[lhs_ptr], #128\n"
".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
"add %[rhs_ptr], %[rhs_ptr], #128\n"
".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
"cmp w1, w3\n"
".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
"ldr q13, [%[lhs_ptr], #-16]\n"
".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
"ldr q15, [%[rhs_ptr], #-16]\n"
".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
"ldr q12, [%[lhs_ptr], #-32]\n"
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
"ldr q14, [%[rhs_ptr], #-32]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
"blt 81b\n"
".word 0x4f87e098 // sdot v24.4s, v4.16b, v7.4b[0]\n"
".word 0x4fa7e09a // sdot v26.4s, v4.16b, v7.4b[1]\n"
".word 0x4f87e89c // sdot v28.4s, v4.16b, v7.4b[2]\n"
".word 0x4fa7e89e // sdot v30.4s, v4.16b, v7.4b[3]\n"
".word 0x4f86e0b1 // sdot v17.4s, v5.16b, v6.4b[0]\n"
".word 0x4fa6e0b3 // sdot v19.4s, v5.16b, v6.4b[1]\n"
".word 0x4f86e8b5 // sdot v21.4s, v5.16b, v6.4b[2]\n"
".word 0x4fa6e8b7 // sdot v23.4s, v5.16b, v6.4b[3]\n"
".word 0x4f87e0b9 // sdot v25.4s, v5.16b, v7.4b[0]\n"
".word 0x4fa7e0bb // sdot v27.4s, v5.16b, v7.4b[1]\n"
".word 0x4f87e8bd // sdot v29.4s, v5.16b, v7.4b[2]\n"
".word 0x4fa7e8bf // sdot v31.4s, v5.16b, v7.4b[3]\n"
".word 0x4f86e090 // sdot v16.4s, v4.16b, v6.4b[0]\n"
".word 0x4fa6e092 // sdot v18.4s, v4.16b, v6.4b[1]\n"
".word 0x4f86e894 // sdot v20.4s, v4.16b, v6.4b[2]\n"
".word 0x4fa6e896 // sdot v22.4s, v4.16b, v6.4b[3]\n"
".word 0x4f8be118 // sdot v24.4s, v8.16b, v11.4b[0]\n"
".word 0x4fabe11a // sdot v26.4s, v8.16b, v11.4b[1]\n"
".word 0x4f8be91c // sdot v28.4s, v8.16b, v11.4b[2]\n"
".word 0x4fabe91e // sdot v30.4s, v8.16b, v11.4b[3]\n"
".word 0x4f8ae131 // sdot v17.4s, v9.16b, v10.4b[0]\n"
".word 0x4faae133 // sdot v19.4s, v9.16b, v10.4b[1]\n"
".word 0x4f8ae935 // sdot v21.4s, v9.16b, v10.4b[2]\n"
".word 0x4faae937 // sdot v23.4s, v9.16b, v10.4b[3]\n"
".word 0x4f8be139 // sdot v25.4s, v9.16b, v11.4b[0]\n"
".word 0x4fabe13b // sdot v27.4s, v9.16b, v11.4b[1]\n"
".word 0x4f8be93d // sdot v29.4s, v9.16b, v11.4b[2]\n"
".word 0x4fabe93f // sdot v31.4s, v9.16b, v11.4b[3]\n"
".word 0x4f8ae110 // sdot v16.4s, v8.16b, v10.4b[0]\n"
".word 0x4faae112 // sdot v18.4s, v8.16b, v10.4b[1]\n"
".word 0x4f8ae914 // sdot v20.4s, v8.16b, v10.4b[2]\n"
".word 0x4faae916 // sdot v22.4s, v8.16b, v10.4b[3]\n"
".word 0x4f8fe198 // sdot v24.4s, v12.16b, v15.4b[0]\n"
".word 0x4fafe19a // sdot v26.4s, v12.16b, v15.4b[1]\n"
".word 0x4f8fe99c // sdot v28.4s, v12.16b, v15.4b[2]\n"
".word 0x4fafe99e // sdot v30.4s, v12.16b, v15.4b[3]\n"
".word 0x4f8ee1b1 // sdot v17.4s, v13.16b, v14.4b[0]\n"
".word 0x4faee1b3 // sdot v19.4s, v13.16b, v14.4b[1]\n"
".word 0x4f8ee9b5 // sdot v21.4s, v13.16b, v14.4b[2]\n"
".word 0x4faee9b7 // sdot v23.4s, v13.16b, v14.4b[3]\n"
".word 0x4f8fe1b9 // sdot v25.4s, v13.16b, v15.4b[0]\n"
".word 0x4fafe1bb // sdot v27.4s, v13.16b, v15.4b[1]\n"
".word 0x4f8fe9bd // sdot v29.4s, v13.16b, v15.4b[2]\n"
".word 0x4fafe9bf // sdot v31.4s, v13.16b, v15.4b[3]\n"
".word 0x4f8ee190 // sdot v16.4s, v12.16b, v14.4b[0]\n"
".word 0x4faee192 // sdot v18.4s, v12.16b, v14.4b[1]\n"
".word 0x4f8ee994 // sdot v20.4s, v12.16b, v14.4b[2]\n"
".word 0x4faee996 // sdot v22.4s, v12.16b, v14.4b[3]\n"
"78:\n"
#endif // #if RUY_OPT_ENABLED(RUY_OPT_MAX_STREAMING)
// Ordinary kernel inner loop (over depth): the simpler loop of which
// the above was an equivalent 4x-partially-unrolled version.
// Reminder - w1 is how many levels of depth we have already loaded
// data for, w12 is the total depth.
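// In C terms, this loop is roughly (illustrative sketch):
//
//   while (w1 < depth) {
//     w1 += 4;
//     // 16 sdot multiply-adds on the previously loaded v0/v1 (LHS) and
//     // v2/v3 (RHS) data, interleaved with the ld1 loads of the next
//     // 32 bytes of LHS and 32 bytes of RHS.
//   }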
"cmp w1, w12\n"
"beq 79f\n"
"2:\n"
// Because of the data that we have already loaded, we can start the
// loop body right away with some multiply-adds.
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
// Each iteration of this loop advances by 4 levels of depth.
"add w1, w1, #4\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
// Loop termination condition.
"cmp w1, w12\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"blt 2b\n"
"79:\n"
// End of the inner loop on depth. Now perform the remaining
// multiply-adds of the last 4 levels of depth, for which the LHS
// and RHS data is already loaded.
".word 0x4f83e018 // sdot v24.4s, v0.16b, v3.4b[0]\n"
".word 0x4fa3e01a // sdot v26.4s, v0.16b, v3.4b[1]\n"
".word 0x4f83e81c // sdot v28.4s, v0.16b, v3.4b[2]\n"
".word 0x4fa3e81e // sdot v30.4s, v0.16b, v3.4b[3]\n"
".word 0x4f82e031 // sdot v17.4s, v1.16b, v2.4b[0]\n"
".word 0x4fa2e033 // sdot v19.4s, v1.16b, v2.4b[1]\n"
".word 0x4f82e835 // sdot v21.4s, v1.16b, v2.4b[2]\n"
".word 0x4fa2e837 // sdot v23.4s, v1.16b, v2.4b[3]\n"
".word 0x4f83e039 // sdot v25.4s, v1.16b, v3.4b[0]\n"
".word 0x4fa3e03b // sdot v27.4s, v1.16b, v3.4b[1]\n"
".word 0x4f83e83d // sdot v29.4s, v1.16b, v3.4b[2]\n"
".word 0x4fa3e83f // sdot v31.4s, v1.16b, v3.4b[3]\n"
// End of accumulation. The registers v16 -- v31 contain the final
// int32 accumulator values of the current 8x8 destination block.
// We now have to compute the final 8-bit values from these int32
// accumulators, and advance to the next 8x8 block. We intertwine
// these two aspects whenever possible for optimal pipelining, both
// at the data flow level (prefetch data for the next block as early as
// possible) and at the instruction pipelining level (some of the
// next-block work can dual-issue with some of the final work on the
// current block).
// Logic to advance to the next block in preparation for the next
// iteration of the main loop. For now, we only want to compute
// the LHS and RHS data pointers, lhs_col_ptr and rhs_col_ptr. We are
// not yet ready to update the values of row and col, as we still need
// the current values for the rest of the work on the current block.
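// In C terms, this pointer logic is roughly (illustrative sketch;
// the lsl #3 below multiplies the byte strides by 8):
//
//   if (row < last_row) {
//     lhs_col_ptr += 8 * lhs_stride;    // advance to the next row block
//   } else {
//     lhs_col_ptr = lhs_base_ptr;       // wrap back to the first row
//     if (col < last_col) {
//       rhs_col_ptr += 8 * rhs_stride;  // advance to the next column block
//     }
//   }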
"cmp %w[row], w7\n" // Have we finished the last row?
"bge 4f\n" // If finished last row, go to 4
// Not finished last row: then advance to next row.
"add %[lhs_col_ptr], %[lhs_col_ptr], x9, lsl #3\n"
"b 5f\n"
"4:\n" // Finished last row...
"mov %[lhs_col_ptr], x5\n" // Go back to first row
// Now we need to advance to the next column. If we already
// finished the last column, then in principle we are done; however,
// we can't just return here, as we need to allow the end work of the
// current block to complete. The good news is that at this point it
// doesn't matter what data we load for the next column, since
// we will exit from the main loop below before actually storing
// anything computed from that data.
"cmp %w[col], w8\n" // Have we finished the last column?
"bge 5f\n" // If yes, just carry on without updating the column pointer.
// Not finished last column: then advance to next column.
"add %[rhs_col_ptr], %[rhs_col_ptr], x10, lsl #3\n"
"5:\n"
// Set the LHS and RHS data pointers to the start of the columns just
// computed.
"mov %[lhs_ptr], %[lhs_col_ptr]\n"
"mov %[rhs_ptr], %[rhs_col_ptr]\n"
// Load some parameters needed for the end work on current block.
RUY_MAKE_ZERO(v8)
"ldr w4, [%[params], #" RUY_STR(RUY_OFFSET_DST_ZERO_POINT) "]\n"
"ldr w3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
"ins v13.h[4], w4\n" // dst_zero_point
"ldr x4, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
"ldrb w6, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
"dup v9.4s, w3\n" // create prod_zp_depth_vec
"add x5, x4, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"csel x4, x4, x5, eq\n"
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
"add x5, x1, %x[row], lsl #2\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
"csel x1, x1, x5, eq\n"
// Load 8 bias values.
"ld1 {v14.4s}, [x1], #16\n"
"ld1 {v15.4s}, [x1]\n"
// Now that we know what LHS and RHS data the next iteration of the
// main loop will need to load, we start loading the first 32 bytes of
// each of LHS and RHS, into v0 -- v3, as we don't need v0 -- v3 anymore
// in the rest of the work on the current block.
"ld1 {v0.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v1.16b}, [%[lhs_ptr]], #16\n"
"ld1 {v2.16b}, [%[rhs_ptr]], #16\n"
"ld1 {v3.16b}, [%[rhs_ptr]], #16\n"
// Add to the bias values the product
// (depth * lhs_zero_point * rhs_zero_point). See the term NZ1Z2 in
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
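// Putting all the zero-point corrections together, each accumulator
// ends up as (illustrative sketch of equation (7); the sums terms are
// applied further below, gated on the RHS_SUMS/LHS_SUMS flags):
//
//   acc[r][c] += bias[r] + depth * lhs_zero_point * rhs_zero_point
//                - lhs_zero_point * rhs_sums[c]
//                - rhs_zero_point * lhs_sums[r];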
"add v14.4s, v14.4s, v9.4s\n"
"add v15.4s, v15.4s, v9.4s\n"
// Perform the bias-addition (per the above, we have just folded into
// the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
"add v16.4s, v16.4s, v14.4s\n"
"add v17.4s, v17.4s, v15.4s\n"
"add v18.4s, v18.4s, v14.4s\n"
"add v19.4s, v19.4s, v15.4s\n"
"add v20.4s, v20.4s, v14.4s\n"
"add v21.4s, v21.4s, v15.4s\n"
"add v22.4s, v22.4s, v14.4s\n"
"add v23.4s, v23.4s, v15.4s\n"
"add v24.4s, v24.4s, v14.4s\n"
"add v25.4s, v25.4s, v15.4s\n"
"add v26.4s, v26.4s, v14.4s\n"
"add v27.4s, v27.4s, v15.4s\n"
"add v28.4s, v28.4s, v14.4s\n"
"add v29.4s, v29.4s, v15.4s\n"
"add v30.4s, v30.4s, v14.4s\n"
"add v31.4s, v31.4s, v15.4s\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_RHS_SUMS) "\n"
"beq 401f\n"
"ldr x3, [%[params], #" RUY_STR(RUY_OFFSET_RHS_SUMS) "]\n"
"add x3, x3, %x[col], lsl #2\n"
"ld1 {v14.4s}, [x3], #16\n"
"ld1 {v15.4s}, [x3]\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_ZERO_POINT) "]\n"
"dup v10.4s, w5\n" // create lhs_zero_point_vec
// Subtract rhs_sums * lhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"mls v16.4s, v10.4s, v14.s[0]\n"
"mls v17.4s, v10.4s, v14.s[0]\n"
"mls v18.4s, v10.4s, v14.s[1]\n"
"mls v19.4s, v10.4s, v14.s[1]\n"
"mls v20.4s, v10.4s, v14.s[2]\n"
"mls v21.4s, v10.4s, v14.s[2]\n"
"mls v22.4s, v10.4s, v14.s[3]\n"
"mls v23.4s, v10.4s, v14.s[3]\n"
"mls v24.4s, v10.4s, v15.s[0]\n"
"mls v25.4s, v10.4s, v15.s[0]\n"
"mls v26.4s, v10.4s, v15.s[1]\n"
"mls v27.4s, v10.4s, v15.s[1]\n"
"mls v28.4s, v10.4s, v15.s[2]\n"
"mls v29.4s, v10.4s, v15.s[2]\n"
"mls v30.4s, v10.4s, v15.s[3]\n"
"mls v31.4s, v10.4s, v15.s[3]\n"
"401:\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_LHS_SUMS) "\n"
"beq 402f\n"
"ldr x2, [%[params], #" RUY_STR(RUY_OFFSET_LHS_SUMS) "]\n"
"add x2, x2, %x[row], lsl #2\n"
"ldr w5, [%[params], #" RUY_STR(RUY_OFFSET_RHS_ZERO_POINT) "]\n"
// Load 4 lhs_sums values.
"ld1 {v11.4s}, [x2], #16\n"
"ld1 {v12.4s}, [x2]\n"
"ins v13.s[1], w5\n" // rhs_zero_point
// Compute lhs_sums * rhs_zero_point.
"mul v11.4s, v11.4s, v13.s[1]\n"
"mul v12.4s, v12.4s, v13.s[1]\n"
// Subtract lhs_sums * rhs_zero_point, per
// equation (7) in https://arxiv.org/pdf/1712.05877.pdf
"sub v16.4s, v16.4s, v11.4s\n"
"sub v17.4s, v17.4s, v12.4s\n"
"sub v18.4s, v18.4s, v11.4s\n"
"sub v19.4s, v19.4s, v12.4s\n"
"sub v20.4s, v20.4s, v11.4s\n"
"sub v21.4s, v21.4s, v12.4s\n"
"sub v22.4s, v22.4s, v11.4s\n"
"sub v23.4s, v23.4s, v12.4s\n"
"sub v24.4s, v24.4s, v11.4s\n"
"sub v25.4s, v25.4s, v12.4s\n"
"sub v26.4s, v26.4s, v11.4s\n"
"sub v27.4s, v27.4s, v12.4s\n"
"sub v28.4s, v28.4s, v11.4s\n"
"sub v29.4s, v29.4s, v12.4s\n"
"sub v30.4s, v30.4s, v11.4s\n"
"sub v31.4s, v31.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT32) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT32) "f\n"
"402:\n"
// At this point we have computed the final int32 values. Now we
// start down-quantizing them to obtain the final 8bit values.
// As part of this down-quantization, our int32 values will be
// multiplied by a multiplier that has a fixed-point component and an
// exponent component.
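// Per lane, the down-quantization performed below is roughly
// (illustrative sketch; the fixed-point multiply corresponds to
// gemmlowp's SaturatingRoundingDoublingHighMul):
//
//   int left_shift = max(exponent, 0);
//   int right_shift = min(exponent, 0);  // encoded as a negative shift
//   acc = acc << left_shift;                                   // sshl
//   acc = SaturatingRoundingDoublingHighMul(acc, fixedpoint);  // sqrdmulh
//   acc = RoundingRightShift(acc, -right_shift);               // srshl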
// Load the exponent part of the multiplier.
"ldr x1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
"add x5, x1, %x[row], lsl #2\n"
"csel x1, x1, x5, eq\n"
"ldr q9, [x1]\n"
"ldr q10, [x1, #16]\n"
"tst w6, #" RUY_STR(RUY_ASM_FLAG_NEEDS_LEFT_SHIFT) "\n"
"beq 403f\n"
"smax v11.4s, v9.4s, v8.4s\n"
"smax v12.4s, v10.4s, v8.4s\n"
"sshl v16.4s, v16.4s, v11.4s\n"
"sshl v17.4s, v17.4s, v12.4s\n"
"sshl v18.4s, v18.4s, v11.4s\n"
"sshl v19.4s, v19.4s, v12.4s\n"
"sshl v20.4s, v20.4s, v11.4s\n"
"sshl v21.4s, v21.4s, v12.4s\n"
"sshl v22.4s, v22.4s, v11.4s\n"
"sshl v23.4s, v23.4s, v12.4s\n"
"sshl v24.4s, v24.4s, v11.4s\n"
"sshl v25.4s, v25.4s, v12.4s\n"
"sshl v26.4s, v26.4s, v11.4s\n"
"sshl v27.4s, v27.4s, v12.4s\n"
"sshl v28.4s, v28.4s, v11.4s\n"
"sshl v29.4s, v29.4s, v12.4s\n"
"sshl v30.4s, v30.4s, v11.4s\n"
"sshl v31.4s, v31.4s, v12.4s\n"
"403:\n"
"ldr q14, [x4]\n" // multiplier_fixedpoint
"ldr q15, [x4, #16]\n" // multiplier_fixedpoint
"smin v11.4s, v9.4s, v8.4s\n"
"smin v12.4s, v10.4s, v8.4s\n"
// Apply the fixed-point part of the multiplier.
"sqrdmulh v16.4s, v16.4s, v14.4s\n"
"sqrdmulh v17.4s, v17.4s, v15.4s\n"
"sqrdmulh v18.4s, v18.4s, v14.4s\n"
"sqrdmulh v19.4s, v19.4s, v15.4s\n"
"sqrdmulh v20.4s, v20.4s, v14.4s\n"
"sqrdmulh v21.4s, v21.4s, v15.4s\n"
"sqrdmulh v22.4s, v22.4s, v14.4s\n"
"sqrdmulh v23.4s, v23.4s, v15.4s\n"
"sqrdmulh v24.4s, v24.4s, v14.4s\n"
"sqrdmulh v25.4s, v25.4s, v15.4s\n"
"sqrdmulh v26.4s, v26.4s, v14.4s\n"
"sqrdmulh v27.4s, v27.4s, v15.4s\n"
"sqrdmulh v28.4s, v28.4s, v14.4s\n"
"sqrdmulh v29.4s, v29.4s, v15.4s\n"
"sqrdmulh v30.4s, v30.4s, v14.4s\n"
"sqrdmulh v31.4s, v31.4s, v15.4s\n"
// We have some rounding division-by-power-of-two to do. This should
// always use "round to nearest". We allow for some freedom in how ties
// are broken, to strike a good compromise between performance on given
// hardware and perfect agreement of results across hardware.
//
// When RUY_OPT_NATIVE_ROUNDING is enabled, we allow for implementation-
// defined tie-breaks to help performance. On NEON, this means that we
// can just use the NEON rounding instructions, such as srshl. They
// happen to break ties upward.
//
// When RUY_OPT_NATIVE_ROUNDING is disabled, we implement strict
// break-ties-away-from-zero, as described in Appendix B of
// https://arxiv.org/pdf/1712.05877.pdf
// When we wrote that, we thought that it would be less biased than the
// NEON upward tie-breaks, and we had observed some improvement on some
// models. However, that is only less biased for data centered at zero,
// which was likely the case in those models, but is not always the
// case. If we wanted something more consistently unbiased, we should
// try breaking ties toward nearest-even.
#if !RUY_OPT_ENABLED(RUY_OPT_NATIVE_ROUNDING)
// Fix up values to be right-shifted, so that the (round to nearest,
// break ties upward) behavior of srshl applied to these fixed-up
// values, produces the same result as the desired (round to nearest,
// break ties away from zero) behavior on the original values.
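// Per lane, the fix-up is roughly (illustrative sketch; 'shift' is the
// negative right-shift amount in v11/v12, so its sign bit is set
// whenever there is an actual right shift):
//
//   int32_t nudge = (acc & shift) >> 31;  // -1 iff acc < 0 && shift < 0
//   acc = SaturatingAdd(acc, nudge);      // sqadd: nudge negative values
//                                         // down by 1, so srshl's upward
//                                         // tie-break lands away from zero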
"and v8.16b, v16.16b, v11.16b\n"
"and v9.16b, v17.16b, v12.16b\n"
"and v14.16b, v18.16b, v11.16b\n"
"and v15.16b, v19.16b, v12.16b\n"
"sshr v8.4s, v8.4s, #31\n"
"sshr v9.4s, v9.4s, #31\n"
"sshr v14.4s, v14.4s, #31\n"
"sshr v15.4s, v15.4s, #31\n"
"sqadd v16.4s, v16.4s, v8.4s\n"
"sqadd v17.4s, v17.4s, v9.4s\n"
"sqadd v18.4s, v18.4s, v14.4s\n"
"sqadd v19.4s, v19.4s, v15.4s\n"
"and v8.16b, v20.16b, v11.16b\n"
"and v9.16b, v21.16b, v12.16b\n"
"and v14.16b, v22.16b, v11.16b\n"
"and v15.16b, v23.16b, v12.16b\n"
"sshr v8.4s, v8.4s, #31\n"
"sshr v9.4s, v9.4s, #31\n"
"sshr v14.4s, v14.4s, #31\n"
"sshr v15.4s, v15.4s, #31\n"
"sqadd v20.4s, v20.4s, v8.4s\n"
"sqadd v21.4s, v21.4s, v9.4s\n"
"sqadd v22.4s, v22.4s, v14.4s\n"
"sqadd v23.4s, v23.4s, v15.4s\n"
"and v8.16b, v24.16b, v11.16b\n"
"and v9.16b, v25.16b, v12.16b\n"
"and v14.16b, v26.16b, v11.16b\n"
"and v15.16b, v27.16b, v12.16b\n"
"sshr v8.4s, v8.4s, #31\n"
"sshr v9.4s, v9.4s, #31\n"
"sshr v14.4s, v14.4s, #31\n"
"sshr v15.4s, v15.4s, #31\n"
"sqadd v24.4s, v24.4s, v8.4s\n"
"sqadd v25.4s, v25.4s, v9.4s\n"
"sqadd v26.4s, v26.4s, v14.4s\n"
"sqadd v27.4s, v27.4s, v15.4s\n"
"and v8.16b, v28.16b, v11.16b\n"
"and v9.16b, v29.16b, v12.16b\n"
"and v14.16b, v30.16b, v11.16b\n"
"and v15.16b, v31.16b, v12.16b\n"
"sshr v8.4s, v8.4s, #31\n"
"sshr v9.4s, v9.4s, #31\n"
"sshr v14.4s, v14.4s, #31\n"
"sshr v15.4s, v15.4s, #31\n"
"sqadd v28.4s, v28.4s, v8.4s\n"
"sqadd v29.4s, v29.4s, v9.4s\n"
"sqadd v30.4s, v30.4s, v14.4s\n"
"sqadd v31.4s, v31.4s, v15.4s\n"
#endif
// At this point we have reduced the problem of correctly implementing
// rounding divide-by-power-of-two, to what the SRSHL instruction can
// do.
"srshl v16.4s, v16.4s, v11.4s\n"
"srshl v17.4s, v17.4s, v12.4s\n"
"srshl v18.4s, v18.4s, v11.4s\n"
"srshl v19.4s, v19.4s, v12.4s\n"
"srshl v20.4s, v20.4s, v11.4s\n"
"srshl v21.4s, v21.4s, v12.4s\n"
"srshl v22.4s, v22.4s, v11.4s\n"
"srshl v23.4s, v23.4s, v12.4s\n"
"srshl v24.4s, v24.4s, v11.4s\n"
"srshl v25.4s, v25.4s, v12.4s\n"
"srshl v26.4s, v26.4s, v11.4s\n"
"srshl v27.4s, v27.4s, v12.4s\n"
"srshl v28.4s, v28.4s, v11.4s\n"
"srshl v29.4s, v29.4s, v12.4s\n"
"srshl v30.4s, v30.4s, v11.4s\n"
"srshl v31.4s, v31.4s, v12.4s\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT16) "f\n"
"cmp %w[dst_type_id], #" RUY_STR(RUY_ASM_TYPE_ID_INT8) "\n"
"beq " RUY_STR(RUY_ASM_LABEL_STORE_INT8) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_UINT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"
"add v17.8h, v17.8h, v14.8h\n"
"add v18.8h, v18.8h, v14.8h\n"
"add v19.8h, v19.8h, v14.8h\n"
"add v20.8h, v20.8h, v14.8h\n"
"add v21.8h, v21.8h, v14.8h\n"
"add v22.8h, v22.8h, v14.8h\n"
"add v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to uint8
"sqxtun v16.8b, v16.8h\n"
"sqxtun2 v16.16b, v17.8h\n"
"sqxtun v17.8b, v18.8h\n"
"sqxtun2 v17.16b, v19.8h\n"
"sqxtun v18.8b, v20.8h\n"
"sqxtun2 v18.16b, v21.8h\n"
"sqxtun v19.8b, v22.8h\n"
"sqxtun2 v19.16b, v23.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"umax v16.16b, v16.16b, v14.16b\n"
"umax v17.16b, v17.16b, v14.16b\n"
"umax v18.16b, v18.16b, v14.16b\n"
"umax v19.16b, v19.16b, v14.16b\n"
// Apply the clamp_max bound
"umin v16.16b, v16.16b, v15.16b\n"
"umin v17.16b, v17.16b, v15.16b\n"
"umin v18.16b, v18.16b, v15.16b\n"
"umin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Compute how much of the 8x8 block of destination 8bit values we
// have computed fits in the destination matrix. Typically, all of it
// fits, but when the destination matrix shape is not a multiple of
// 8x8, there are some 8x8 blocks along the boundaries that do not fit
// entirely.
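// In C terms (illustrative sketch):
//
//   w1 = min(dst_rows - row, 8);  // rows of this 8x8 block that fit
//   w2 = min(dst_cols - col, 8);  // cols of this 8x8 block that fit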
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 30f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 31f\n"
"30:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"31:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 41f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
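// In C terms, this copy loop is roughly (illustrative sketch;
// dst_stride is the destination byte stride held in x11):
//
//   for (int c = 0; c < w2; c++) {
//     for (int r = 0; r < w1; r++) {
//       dst_ptr[r + c * dst_stride] = dst_tmp_buf[r + c * 8];
//     }
//   }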
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"50:\n"
"mov w5, #0\n"
"51:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 51b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 50b\n"
"41:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT8) ":\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"add v16.8h, v16.8h, v14.8h\n"
"add v17.8h, v17.8h, v14.8h\n"
"add v18.8h, v18.8h, v14.8h\n"
"add v19.8h, v19.8h, v14.8h\n"
"add v20.8h, v20.8h, v14.8h\n"
"add v21.8h, v21.8h, v14.8h\n"
"add v22.8h, v22.8h, v14.8h\n"
"add v23.8h, v23.8h, v14.8h\n"
// Cast-and-saturate from int16 to int8
"sqxtn v16.8b, v16.8h\n"
"sqxtn2 v16.16b, v17.8h\n"
"sqxtn v17.8b, v18.8h\n"
"sqxtn2 v17.16b, v19.8h\n"
"sqxtn v18.8b, v20.8h\n"
"sqxtn2 v18.16b, v21.8h\n"
"sqxtn v19.8b, v22.8h\n"
"sqxtn2 v19.16b, v23.8h\n"
// Load the clamp_min, clamp_max bounds
"ldrb w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrb w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.16b, w2\n" // clamp_min
"dup v15.16b, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.16b, v16.16b, v14.16b\n"
"smax v17.16b, v17.16b, v14.16b\n"
"smax v18.16b, v18.16b, v14.16b\n"
"smax v19.16b, v19.16b, v14.16b\n"
// Apply the clamp_max bound
"smin v16.16b, v16.16b, v15.16b\n"
"smin v17.16b, v17.16b, v15.16b\n"
"smin v18.16b, v18.16b, v15.16b\n"
"smin v19.16b, v19.16b, v15.16b\n"
// Make it so that all of the final 8bit values are stored in the
// first 64bits of 128bit NEON registers, so they can be stored
// by 64bit st1 store instructions with byte alignment.
"dup d20, v16.d[1]\n"
"dup d21, v17.d[1]\n"
"dup d22, v18.d[1]\n"
"dup d23, v19.d[1]\n"
// Compute how much of the 8x8 block of destination 8bit values we
// have computed fits in the destination matrix. Typically, all of it
// fits, but when the destination matrix shape is not a multiple of
// 8x8, there are some 8x8 blocks along the boundaries that do not fit
// entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 130f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #8\n"
"b 131f\n"
"130:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"131:\n"
// Write our 8bit values to the destination described by
// (x3 address, x4 stride).
"st1 {v16.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
"st1 {v20.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
"st1 {v17.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
"st1 {v21.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
"st1 {v18.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
"st1 {v22.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
"st1 {v19.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
"st1 {v23.8b}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 141f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"150:\n"
"mov w5, #0\n"
"151:\n"
"ldrb w7, [x3, w5, uxtw]\n"
"strb w7, [x4, w5, uxtw]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 151b\n"
"add w6, w6, #1\n"
"add x3, x3, #8\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 150b\n"
"141:\n"
"add %[dst_ptr], %[dst_ptr], #8\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT16) ":\n"
// Add the destination zero point
"dup v14.8h, v13.h[4]\n"
"saddw v16.4s, v16.4s, v14.4h\n"
"saddw v17.4s, v17.4s, v14.4h\n"
"saddw v18.4s, v18.4s, v14.4h\n"
"saddw v19.4s, v19.4s, v14.4h\n"
"saddw v20.4s, v20.4s, v14.4h\n"
"saddw v21.4s, v21.4s, v14.4h\n"
"saddw v22.4s, v22.4s, v14.4h\n"
"saddw v23.4s, v23.4s, v14.4h\n"
"saddw v24.4s, v24.4s, v14.4h\n"
"saddw v25.4s, v25.4s, v14.4h\n"
"saddw v26.4s, v26.4s, v14.4h\n"
"saddw v27.4s, v27.4s, v14.4h\n"
"saddw v28.4s, v28.4s, v14.4h\n"
"saddw v29.4s, v29.4s, v14.4h\n"
"saddw v30.4s, v30.4s, v14.4h\n"
"saddw v31.4s, v31.4s, v14.4h\n"
// Cast-and-saturate from int32 to int16
"sqxtn v16.4h, v16.4s\n"
"sqxtn2 v16.8h, v17.4s\n"
"sqxtn v17.4h, v18.4s\n"
"sqxtn2 v17.8h, v19.4s\n"
"sqxtn v18.4h, v20.4s\n"
"sqxtn2 v18.8h, v21.4s\n"
"sqxtn v19.4h, v22.4s\n"
"sqxtn2 v19.8h, v23.4s\n"
"sqxtn v20.4h, v24.4s\n"
"sqxtn2 v20.8h, v25.4s\n"
"sqxtn v21.4h, v26.4s\n"
"sqxtn2 v21.8h, v27.4s\n"
"sqxtn v22.4h, v28.4s\n"
"sqxtn2 v22.8h, v29.4s\n"
"sqxtn v23.4h, v30.4s\n"
"sqxtn2 v23.8h, v31.4s\n"
// At this point, v24 -- v31 aren't used anymore for the current block,
// so we can start clearing these accumulators for the next block
// (next iteration of the main loop).
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
// Load the clamp_min, clamp_max bounds
"ldrsh w2, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MIN) "]\n"
"ldrsh w3, [%[params], #" RUY_STR(RUY_OFFSET_CLAMP_MAX) "]\n"
"dup v14.8h, w2\n" // clamp_min
"dup v15.8h, w3\n" // clamp_max
// Apply the clamp_min bound
"smax v16.8h, v16.8h, v14.8h\n"
"smax v17.8h, v17.8h, v14.8h\n"
"smax v18.8h, v18.8h, v14.8h\n"
"smax v19.8h, v19.8h, v14.8h\n"
"smax v20.8h, v20.8h, v14.8h\n"
"smax v21.8h, v21.8h, v14.8h\n"
"smax v22.8h, v22.8h, v14.8h\n"
"smax v23.8h, v23.8h, v14.8h\n"
// Apply the clamp_max bound
"smin v16.8h, v16.8h, v15.8h\n"
"smin v17.8h, v17.8h, v15.8h\n"
"smin v18.8h, v18.8h, v15.8h\n"
"smin v19.8h, v19.8h, v15.8h\n"
"smin v20.8h, v20.8h, v15.8h\n"
"smin v21.8h, v21.8h, v15.8h\n"
"smin v22.8h, v22.8h, v15.8h\n"
"smin v23.8h, v23.8h, v15.8h\n"
// Compute how much of the 8x8 block of destination 16bit values we
// have computed fits in the destination matrix. Typically, all of it
// fits, but when the destination matrix shape is not a multiple of
// 8x8, there are some 8x8 blocks along the boundaries that do not fit
// entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 230f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
"b 231f\n"
"230:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x3, %[dst_ptr]\n"
"mov x4, x11\n"
"231:\n"
// Write our 16bit values to the destination described by
// (x3 address, x4 stride).
"st1 {v16.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
"st1 {v17.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
"st1 {v18.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
"st1 {v19.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
"st1 {v20.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
"st1 {v21.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
"st1 {v22.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
"st1 {v23.8h}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 241f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"250:\n"
"mov w5, #0\n"
"251:\n"
"ldrsh w7, [x3, x5, lsl #1]\n"
"strh w7, [x4, x5, lsl #1]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 251b\n"
"add w6, w6, #1\n"
"add x3, x3, #16\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 250b\n"
"241:\n"
"add %[dst_ptr], %[dst_ptr], #16\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
"b " RUY_STR(RUY_ASM_LABEL_AFTER_STORE) "f\n"
RUY_STR(RUY_ASM_LABEL_STORE_INT32) ":\n"
// Since the store type is the same as the accum type, no need for
// downcast. There's also no need to clamp by min/max.
// Compute how much of the 8x8 block of destination 32bit values we
// have computed fits in the destination matrix. Typically, all of it
// fits, but when the destination matrix shape is not a multiple of
// 8x8, there are some 8x8 blocks along the boundaries that do not fit
// entirely.
"sub w1, %w[dst_rows], %w[row]\n"
"sub w2, %w[dst_cols], %w[col]\n"
"mov w3, #8\n"
"cmp w1, #8\n"
// Compute w1 = how many rows of the 8x8 block fit
"csel w1, w1, w3, le\n"
"cmp w2, #8\n"
// Compute w2 = how many cols of the 8x8 block fit
"csel w2, w2, w3, le\n"
// Test if w1 == 8 && w2 == 8, i.e. if all of the 8x8 block fits.
"cmp w1, w3\n"
"ccmp w2, w3, 0, eq\n"
// Yes, all of the 8x8 block fits, go to fast path.
"beq 330f\n"
// Not all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write to dst_tmp_buf
"mov x3, %[dst_tmp_buf]\n"
"mov x4, #16\n"
// Write our 32bit values to the destination described by
// (x3 address, x4 stride).
"st1 {v16.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v16)
"st1 {v17.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v17)
"st1 {v18.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v18)
"st1 {v19.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v19)
"st1 {v20.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v20)
"st1 {v21.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v21)
"st1 {v22.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v22)
"st1 {v23.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v23)
"st1 {v24.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v24)
"st1 {v25.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v25)
"st1 {v26.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v26)
"st1 {v27.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v27)
"st1 {v28.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v28)
"st1 {v29.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v29)
"st1 {v30.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v30)
"st1 {v31.4s}, [x3], x4\n"
RUY_MAKE_ZERO(v31)
"b 331f\n"
"330:\n"
// Yes, all of the 8x8 block fits.
// Set (x3 address, x4 stride) to write directly to destination matrix.
"mov x4, %[dst_ptr]\n"
"mov x3, x4\n"
// Write our 32bit values to the destination described by
// (x3 address, x4 stride).
"st1 {v16.4s, v17.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v16)
RUY_MAKE_ZERO(v17)
"add x4, x4, x11\n"
"mov x3, x4\n"
"st1 {v18.4s, v19.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v18)
RUY_MAKE_ZERO(v19)
"add x4, x4, x11\n"
"mov x3, x4\n"
"st1 {v20.4s, v21.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v20)
RUY_MAKE_ZERO(v21)
"add x4, x4, x11\n"
"mov x3, x4\n"
"st1 {v22.4s, v23.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v22)
RUY_MAKE_ZERO(v23)
"add x4, x4, x11\n"
"mov x3, x4\n"
"st1 {v24.4s, v25.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v24)
RUY_MAKE_ZERO(v25)
"add x4, x4, x11\n"
"mov x3, x4\n"
"st1 {v26.4s, v27.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v26)
RUY_MAKE_ZERO(v27)
"add x4, x4, x11\n"
"mov x3, x4\n"
"st1 {v28.4s, v29.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v28)
RUY_MAKE_ZERO(v29)
"add x4, x4, x11\n"
"mov x3, x4\n"
"st1 {v30.4s, v31.4s}, [x3], #32\n"
RUY_MAKE_ZERO(v30)
RUY_MAKE_ZERO(v31)
"331:\n"
// For the next block: perform the first few multiply-adds on the data
// that we have already loaded.
".word 0x4f82e010 // sdot v16.4s, v0.16b, v2.4b[0]\n"
".word 0x4fa2e012 // sdot v18.4s, v0.16b, v2.4b[1]\n"
".word 0x4f82e814 // sdot v20.4s, v0.16b, v2.4b[2]\n"
".word 0x4fa2e816 // sdot v22.4s, v0.16b, v2.4b[3]\n"
// If all of the 8x8 block fits, we just finished writing it to the
// destination, so we skip the next part.
"beq 341f\n"
// Not all of the 8x8 block fits in the destination matrix. We just
// wrote it to dst_tmp_buf. Now we perform the slow scalar loop over
// it to copy into the destination matrix the part that fits.
"mov x3, %[dst_tmp_buf]\n"
"mov x4, %[dst_ptr]\n"
"mov w6, #0\n"
"350:\n"
"mov w5, #0\n"
"351:\n"
"ldr w7, [x3, x5, lsl #2]\n"
"str w7, [x4, x5, lsl #2]\n"
"add w5, w5, #1\n"
"cmp w5, w1\n"
"blt 351b\n"
"add w6, w6, #1\n"
"add x3, x3, #32\n"
"add x4, x4, x11\n"
"cmp w6, w2\n"
"blt 350b\n"
"341:\n"
"add %[dst_ptr], %[dst_ptr], #32\n"
// At this point we have completely finished writing values to the
// destination matrix for the current block.
RUY_STR(RUY_ASM_LABEL_AFTER_STORE) ":\n"
// Reload some params --- we had used x5 -- x7 for a few other things
// since the last time we had loaded them.
"ldr x5, [%[params], #" RUY_STR(RUY_OFFSET_LHS_BASE_PTR) "]\n"
"ldr w6, [%[params], #" RUY_STR(RUY_OFFSET_START_ROW) "]\n"
"ldr w7, [%[params], #" RUY_STR(RUY_OFFSET_LAST_ROW) "]\n"
// Move to the next block of the destination matrix, for the next iter
// of the main loop. Notice that lhs_col_ptr, rhs_col_ptr have already
// been updated earlier.
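// In C terms, this traversal is roughly (illustrative sketch):
//
//   if (row == last_row) {
//     row = start_row;               // wrap back to the first row...
//     col += 8;                      // ...of the next column block
//     dst_col_ptr += 8 * dst_stride;
//     dst_ptr = dst_col_ptr;
//   } else {
//     row += 8;
//   }
//   // ...and loop again unless col > last_col.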
// Have we reached the end row?
"cmp %w[row], w7\n"
"beq 20f\n" // yes, end row.
// Not end row. Move to the next row.
"add %w[row], %w[row], #8\n"
"b 21f\n"
"20:\n"
// Was already at end row.
"mov %w[row], w6\n" // Move back to first row.
"add %w[col], %w[col], #8\n" // Move to the next column.
"add %[dst_col_ptr], %[dst_col_ptr], x11, lsl #3\n"
"mov %[dst_ptr], %[dst_col_ptr]\n"
"21:\n"
// Main loop exit condition: have we hit the end column?
"cmp %w[col], w8\n"
// w1 is the number of levels of depth that we have already loaded
// LHS and RHS data for. Corresponding to the initial ld1 instructions
// above, this is currently 4.
"mov w1, #4\n"
"ble 1b\n"
// clang-format on
: [lhs_col_ptr] "+r"(lhs_col_ptr), [rhs_col_ptr] "+r"(rhs_col_ptr),
[lhs_ptr] "+r"(lhs_ptr), [rhs_ptr] "+r"(rhs_ptr),
[dst_col_ptr] "+r"(dst_col_ptr), [dst_ptr] "+r"(dst_ptr), [row] "+r"(row), [col] "+r"(col)
: [params] "r"(&params), [dst_rows] "r"(params.dst_rows),
[dst_cols] "r"(params.dst_cols), [dst_tmp_buf] "r"(params.dst_tmp_buf),
[dst_type_id] "r"(params.dst_type_id)
: "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8", "x9", "x10", "x11", "x12", "x13", "cc",
"memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12",
"v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25",
"v26", "v27", "v28", "v29", "v30", "v31");
}