maga_transformer/cpp/devices/arm_impl/gemm_opt/ArmGemmThreadblock.cc (875 lines of code) (raw):
#include <arm_sve.h>
#include <cstring>
// #define DEBUG
#ifdef DEBUG
#include <iomanip>
#endif
#include "ArmGemmKernel.h"
#include "gemm_microkernel_macro_m8_bf16.h"
#include "activation_const.hpp"
#include "arm_common.h"
namespace rtp_llm {
/**
 * Computes one tile of C (fp32 output) from packed bf16 A and B with SVE
 * BFMMLA, over the K slice that starts at `k` and spans `k_tile`.
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results and
 *      sets the column predicates p1/p2 from n and N.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): stepping the A/B row pointers by 2 suggests each pointer
 * addresses an interleaved row pair as BFMMLA expects -- confirm against the
 * packing code.
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements; *2 is for processing 2*k_tile per kernel
    // (presumably because rows are packed in interleaved pairs).
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp32_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp32_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n;
    int k_init = k * 2;                        // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;              // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16; // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        // "mov x7, #0 \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* calculate the remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x4, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT
        "whilelt p1.s, x3, %[N] \n"             // p1: word lane i active while n + i < N
        "add x6, x3, #4 \n"                     // x6 = n + 4
        "whilelt p2.s, x6, %[N] \n"             // p2: word lane i active while n + 4 + i < N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_ACCUMULATE
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_STORE
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
    return;
}
/*********************************************************/
/**
 * Computes one tile of C (float16_t output) from packed bf16 A and B with SVE
 * BFMMLA, over the K slice that starts at `k` and spans `k_tile`.
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results via
 *      the *_FP16 reorder macro and sets the halfword column predicate p1.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile as fp16.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): stepping the A/B row pointers by 2 suggests each pointer
 * addresses an interleaved row pair as BFMMLA expects -- confirm against the
 * packing code.
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float16_t, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements; *2 is for processing 2*k_tile per kernel
    // (presumably because rows are packed in interleaved pairs).
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp16_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp16_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float16_t);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n; // TODO: handle float16_t bias
    int k_init = k * 2;                            // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;                  // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16; // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        // "mov x7, #0 \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* calculate the remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x4, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT_FP16
        "whilelt p1.h, x3, %[N] \n"             // p1: halfword lane i active while n + i < N
        // "add x6, x3, #4 \n" // n + 4
        // "whilelt p2.s, x6, %[N] \n" // compare n, N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5", "p3",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_ACCUMULATE_FP16
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_STORE_FP16
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
    return;
}
/*********************************************************/
/**
 * Row-residual variant of thread_block_bf16_m8 (fp32 output): same K loop, but
 * A is loaded with ASM_BLOCK_LOAD_A_RES and C is accumulated/stored with the
 * ASM_BLOCK_C_RES_* macros (presumably handling a partial tile at the M
 * boundary -- confirm against the macro definitions).
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results and
 *      sets the column predicates p1/p2 from n and N.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8_mres(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
#define LABEL_SKIP_STORE "3"
#define LABEL_SKIP_LD_A1 "4"
#define LABEL_SKIP_LD_W1 "5"
#define LABEL_SKIP_ACCUMULATE "6"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements; the factor 2 on k matches k_init/K_MAX below.
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp32_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp32_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n;
    int k_init = k * 2;                        // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;              // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16;              // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        " " LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x0, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT
        "whilelt p1.s, x3, %[N] \n"             // p1: word lane i active while n + i < N
        "add x6, x3, #4 \n"                     // x6 = n + 4
        "whilelt p2.s, x6, %[N] \n"             // p2: word lane i active while n + 4 + i < N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_RES_ACCUMULATE
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_RES_STORE
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
#undef LABEL_SKIP_STORE
#undef LABEL_SKIP_LD_A1
#undef LABEL_SKIP_LD_W1
#undef LABEL_SKIP_ACCUMULATE
    return;
}
/**
 * Row-residual variant of thread_block_bf16_m8 (float16_t output): same K loop,
 * but A is loaded with ASM_BLOCK_LOAD_A_RES and C is accumulated/stored with
 * the ASM_BLOCK_C_RES_*_FP16 macros (presumably handling a partial tile at the
 * M boundary -- confirm against the macro definitions).
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results via
 *      the *_FP16 reorder macro and sets the halfword column predicate p1.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile as fp16.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8_mres(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float16_t, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
#define LABEL_SKIP_STORE "3"
#define LABEL_SKIP_LD_A1 "4"
#define LABEL_SKIP_LD_W1 "5"
#define LABEL_SKIP_ACCUMULATE "6"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements; the factor 2 on k matches k_init/K_MAX below.
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp16_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp16_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float16_t);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n;
    int k_init = k * 2;                            // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;                  // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16;                  // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        " " LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x0, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT_FP16
        "whilelt p1.h, x3, %[N] \n"             // p1: halfword lane i active while n + i < N
        // "add x6, x3, #4 \n" // n + 4
        // "whilelt p2.s, x6, %[N] \n" // compare n, N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5", "p3",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_RES_ACCUMULATE_FP16
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_RES_STORE_FP16
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
#undef LABEL_SKIP_STORE
#undef LABEL_SKIP_LD_A1
#undef LABEL_SKIP_LD_W1
#undef LABEL_SKIP_ACCUMULATE
    return;
}
/*********************************************************/
/**
 * Column-residual variant of thread_block_bf16_m8 (fp32 output): same K loop,
 * but B is loaded with ASM_BLOCK_LOAD_B_RES (presumably handling a partial tile
 * at the N boundary -- confirm against the macro definition). C is accumulated
 * and stored with the regular ASM_BLOCK_C_ACCUMULATE / ASM_BLOCK_C_STORE macros.
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results and
 *      sets the column predicates p1/p2 from n and N.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8_nres(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
#define LABEL_SKIP_STORE "3"
#define LABEL_SKIP_LD_A1 "4"
#define LABEL_SKIP_LD_W1 "5"
#define LABEL_SKIP_ACCUMULATE "6"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements (pointer arithmetic is already scaled by
    // sizeof(bfloat16)); the factor 2 on k matches k_init/K_MAX below.
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp32_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp32_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n;
    int k_init = k * 2;                        // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;              // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16;              // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        // "mov x7, #0 \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        " " LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x4, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT
        "whilelt p1.s, x3, %[N] \n"             // p1: word lane i active while n + i < N
        "add x6, x3, #4 \n"                     // x6 = n + 4
        "whilelt p2.s, x6, %[N] \n"             // p2: word lane i active while n + 4 + i < N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_ACCUMULATE
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_STORE
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
#undef LABEL_SKIP_STORE
#undef LABEL_SKIP_LD_A1
#undef LABEL_SKIP_LD_W1
#undef LABEL_SKIP_ACCUMULATE
    return;
}
/**
 * Column-residual variant of thread_block_bf16_m8 (float16_t output): same K
 * loop, but B is loaded with ASM_BLOCK_LOAD_B_RES (presumably handling a
 * partial tile at the N boundary -- confirm against the macro definition).
 * C is accumulated and stored with the regular *_FP16 accumulate/store macros.
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results via
 *      the *_FP16 reorder macro and sets the halfword column predicate p1.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile as fp16.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8_nres(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float16_t, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
#define LABEL_SKIP_STORE "3"
#define LABEL_SKIP_LD_A1 "4"
#define LABEL_SKIP_LD_W1 "5"
#define LABEL_SKIP_ACCUMULATE "6"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements (pointer arithmetic is already scaled by
    // sizeof(bfloat16)); the factor 2 on k matches k_init/K_MAX below.
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp16_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp16_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float16_t);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n;
    int k_init = k * 2;                            // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;                  // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16;                  // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        // "mov x7, #0 \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        " " LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x4, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT_FP16
        "whilelt p1.h, x3, %[N] \n"             // p1: halfword lane i active while n + i < N
        // "add x6, x3, #4 \n" // n + 4
        // "whilelt p2.s, x6, %[N] \n" // compare n, N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5", "p3",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_ACCUMULATE_FP16
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_STORE_FP16
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
#undef LABEL_SKIP_STORE
#undef LABEL_SKIP_LD_A1
#undef LABEL_SKIP_LD_W1
#undef LABEL_SKIP_ACCUMULATE
    return;
}
/*********************************************************/
/**
 * Row- and column-residual variant of thread_block_bf16_m8 (fp32 output).
 * It uses the same K loop, but loads A and B with ASM_BLOCK_LOAD_A_RES /
 * ASM_BLOCK_LOAD_B_RES and accumulates/stores C with the ASM_BLOCK_C_RES_*
 * macros. These presumably handle a partial tile at both the M and N
 * boundaries -- confirm against the macro definitions.
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results and
 *      sets the column predicates p1/p2 from n and N.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8_res(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
#define LABEL_SKIP_STORE "3"
#define LABEL_SKIP_LD_A1 "4"
#define LABEL_SKIP_LD_W1 "5"
#define LABEL_SKIP_ACCUMULATE "6"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements; the factor 2 on k matches k_init/K_MAX below.
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp32_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp32_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n;
    int k_init = k * 2;                        // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;              // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16;              // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        " " LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x0, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT
        "whilelt p1.s, x3, %[N] \n"             // p1: word lane i active while n + i < N
        "add x6, x3, #4 \n"                     // x6 = n + 4
        "whilelt p2.s, x6, %[N] \n"             // p2: word lane i active while n + 4 + i < N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_RES_ACCUMULATE
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_RES_STORE
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
#undef LABEL_SKIP_STORE
#undef LABEL_SKIP_LD_A1
#undef LABEL_SKIP_LD_W1
#undef LABEL_SKIP_ACCUMULATE
    return;
}
/**
 * Row- and column-residual variant of thread_block_bf16_m8 (float16_t output).
 * It uses the same K loop, but loads A and B with ASM_BLOCK_LOAD_A_RES /
 * ASM_BLOCK_LOAD_B_RES and accumulates/stores C with the
 * ASM_BLOCK_C_RES_*_FP16 macros. These presumably handle a partial tile at
 * both the M and N boundaries -- confirm against the macro definitions.
 *
 * Steps:
 *   1. One asm statement clears the accumulators, runs the main K loop
 *      (x0 advances by 16 packed elements per iteration), runs one predicated
 *      tail step for the remainder up to K_MAX, reorders the BFMMLA results via
 *      the *_FP16 reorder macro and sets the halfword column predicate p1.
 *   2. The ASM_BLOCK_* macros that follow add the bias (only when p.with_bias
 *      and k == 0), accumulate into C (k != 0), apply the activation selected by
 *      p.actType (only when p.do_act == 1) and store the tile as fp16.
 *
 * @param p       GEMM partition parameters (packed A/B, C, bias, sizes, activation).
 * @param m       First output row; A is read at rows m, m+2, m+4, m+6 of p.a_ptr.
 * @param n       First output column; B is read at rows n, n+2, n+4, n+6 of p.b_ptr.
 * @param k       Start of the K slice (multiplied by 2 when indexing packed data).
 * @param k_tile  Length of the K slice; the packed end index is clamped to p.K_pack * 2.
 *
 * NOTE(review): the ASM_BLOCK_* macros after the main asm statement presumably
 * read the z/p registers it leaves behind; the compiler does not promise to
 * preserve those between separate asm statements -- verify the generated code.
 * NOTE(review): the activation is applied on every K slice; if k_tile does not
 * cover all of K, callers presumably set p.do_act only for the last slice -- confirm.
 */
void GemmKernel::thread_block_bf16_m8_res(
    GemmPartParam<hie::bfloat16, hie::bfloat16, float16_t, float>& p, int m, int n, int k, int k_tile) {
#define LABEL_FOR_LOOP_K "1"
#define LABEL_SKIP_PRF "2"
#define LABEL_SKIP_STORE "3"
#define LABEL_SKIP_LD_A1 "4"
#define LABEL_SKIP_LD_W1 "5"
#define LABEL_SKIP_ACCUMULATE "6"
    int M = p.M;
    int N = p.N;
    // Offsets are in bf16 elements; the factor 2 on k matches k_init/K_MAX below.
    hie::bfloat16* a_bf16_ptr1 = p.a_ptr + (m + 0) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr2 = p.a_ptr + (m + 2) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr3 = p.a_ptr + (m + 4) * p.K_pack + k * 2;
    hie::bfloat16* a_bf16_ptr4 = p.a_ptr + (m + 6) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr1 = p.b_ptr + (n + 0) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr2 = p.b_ptr + (n + 2) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr3 = p.b_ptr + (n + 4) * p.K_pack + k * 2;
    hie::bfloat16* b_bf16_ptr4 = p.b_ptr + (n + 6) * p.K_pack + k * 2;
    // c_fp16_ptr, bias_ptr and constant are presumably consumed by the
    // ASM_BLOCK_* macros below (they are not referenced directly here).
    uint64_t c_fp16_ptr = reinterpret_cast<uint64_t>(p.c_ptr + (m + 0) * N + n);
    int next_line_offset = N * sizeof(float16_t);  // byte stride between rows of C
    float* bias_ptr = p.bias_ptr + n;
    int k_init = k * 2;                            // first packed K index of this slice
    int K_MAX = (k + k_tile) * 2;                  // one past the last packed K index
    K_MAX = K_MAX < p.K_pack * 2 ? K_MAX : p.K_pack * 2;
    int K_MAIN = K_MAX / 16 * 16;                  // floor to the 16-element loop step
    activation_const_t constant;
    // clang-format off
    asm volatile(
        "ptrue p0.b \n"
        "ptrue p4.b \n"
        "ptrue p5.b \n"                         // p4/p5 all-true for the main loop; narrowed for the tail below
        // ASM_BLOCK_PREFETCH_PART_0
        "mov x0, %[k_init] \n"                  // x0 = k (packed K index)
        "mov x2, %[m] \n"
        "mov x3, %[n] \n"
        /* clear bfmmla result regs */
        ASM_BLOCK_CLEAR_BFMMLA_REG
        // NOTE(review): this is a do-while loop -- the body executes once with
        // all-true p4/p5 even when K_MAIN <= k_init; presumably the K slice is
        // >= 16 packed elements or the packed data is zero-padded -- confirm.
        " " LABEL_FOR_LOOP_K ":\n"
        /* load bf16 input & weight */
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        "add x0, x0, #16 \n"                    // k += 16
        "cmp x0, %[K_MAIN] \n"                  // compare k and K_MAIN
        "b.tstop " LABEL_FOR_LOOP_K "b \n"      // b.tstop is the SVE alias of b.lt: loop while k < K_MAIN
        /* remaining A and B (fewer than 16 packed elements) */
        /* load bf16 input & weight */
        "mov x4, x0 \n"
        "whilelt p5.h, x4, %[K_MAX] \n"         // p5: halfword lane i active while k + i < K_MAX
        "add x4, x0, #8 \n"
        "whilelt p4.h, x4, %[K_MAX] \n"         // p4: halfword lane i active while k + 8 + i < K_MAX
        ASM_BLOCK_LOAD_A_RES
        ASM_BLOCK_LOAD_B_RES
        // ASM_BLOCK_PREFETCH_PART_1
        /* matmul */
        ASM_BLOCK_BFMMLA
        /* reorder mmla output */
        ASM_BLOCK_REORDER_BFMMLA_OUTPUT_FP16
        "whilelt p1.h, x3, %[N] \n"             // p1: halfword lane i active while n + i < N
        // "add x6, x3, #4 \n" // n + 4
        // "whilelt p2.s, x6, %[N] \n" // compare n, N
        : /* empty OutputOperands */
        // Integer operands are widened to int64_t: the asm uses them as 64-bit
        // X registers ("mov x0, %[k_init]", "cmp x0, %[K_MAIN]", "whilelt ..., %[N]"),
        // and with a 32-bit int operand the upper register bits are unspecified.
        : [a_bf16_ptr1] "r"(a_bf16_ptr1), [a_bf16_ptr2] "r"(a_bf16_ptr2),
          [a_bf16_ptr3] "r"(a_bf16_ptr3), [a_bf16_ptr4] "r"(a_bf16_ptr4),
          [b_bf16_ptr1] "r"(b_bf16_ptr1), [b_bf16_ptr2] "r"(b_bf16_ptr2),
          [b_bf16_ptr3] "r"(b_bf16_ptr3), [b_bf16_ptr4] "r"(b_bf16_ptr4),
          [next_line_offset] "r"(static_cast<int64_t>(next_line_offset)),
          [m] "r"(static_cast<int64_t>(m)), [n] "r"(static_cast<int64_t>(n)),
          [k_init] "r"(static_cast<int64_t>(k_init)),
          [M] "r"(static_cast<int64_t>(M)), [N] "r"(static_cast<int64_t>(N)),
          [K_MAIN] "r"(static_cast<int64_t>(K_MAIN)), [K_MAX] "r"(static_cast<int64_t>(K_MAX))
        : "p0", "p1", "p2", "p4", "p5", "p3",
          "x0", "x1", "x2", "x3", "x4", "x5", "x6", "x7", "x8",
          "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9",
          "z10", "z11", "z12", "z13", "z14", "z15", "z16", "z17", "z18", "z19",
          "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29",
          "z30", "z31",
          "cc", "memory");
    // Bias is added only on the first K slice.
    if (p.with_bias && k == 0) {
        ASM_BLOCK_ADD_BIAS
    }
    // Later K slices accumulate into the values already stored in C.
    if (LIKELY(k != 0)) {
        ASM_BLOCK_C_RES_ACCUMULATE_FP16
    }
    if (p.do_act == 1) {
        switch (p.actType) {
            case UnaryType::UNARYTYPE_UNDEFINED: {
                break;
            }
            case UnaryType::RELU: {
                ASM_BLOCK_ACTIVE_RELU
                break;
            }
            case UnaryType::SILU: {
                ASM_BLOCK_ACTIVE_SILU
                break;
            }
            case UnaryType::TANH: {
                ASM_BLOCK_ACTIVE_TANH
                break;
            }
            case UnaryType::GELU_ERF: {
                ASM_BLOCK_ACTIVE_GELU_ERF
                break;
            }
            case UnaryType::GELU_TANH: {
                ASM_BLOCK_ACTIVE_GELU_TANH
                break;
            }
            default:
                break;
        }
    }
    ASM_BLOCK_C_RES_STORE_FP16
    // clang-format on
#undef LABEL_FOR_LOOP_K
#undef LABEL_SKIP_PRF
#undef LABEL_SKIP_STORE
#undef LABEL_SKIP_LD_A1
#undef LABEL_SKIP_LD_W1
#undef LABEL_SKIP_ACCUMULATE
    return;
}
} // namespace rtp_llm