// maga_transformer/cpp/devices/arm_impl/gemm_opt/ArmGemmPacking.cc
#include <arm_sve.h>
#include <cstring>
// #define PACK_DEBUG
#ifdef PACK_DEBUG
#include <iomanip>
#endif
#include "ArmGemmKernel.h"
#include "gemm_microkernel_macro_m8_bf16.h"
#include "activation_const.hpp"
#include "arm_common.h"
#include "maga_transformer/cpp/core/Buffer.h"
#include "maga_transformer/cpp/devices/DeviceFactory.h"
#include "maga_transformer/cpp/devices/arm_impl/ArmDevice.h"
#include "maga_transformer/cpp/models_weight/W.h"
#include "maga_transformer/cpp/core/torch_utils/BufferTorchUtils.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p_bf16p_interface.h"
#include "kai/ukernels/matmul/matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla.h"
#include "kai/ukernels/matmul/pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon.h"
#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm.h"
#include "kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#define GPTQ_COMPUTE_AS_DI_BF16 0
namespace rtp_llm {
static const size_t kai_num_bytes_multiplier = sizeof(uint16_t);
static const size_t kai_bl = 32;
inline static size_t kai_num_blocks_per_row(size_t k, size_t bl) {
KAI_ASSUME((k % 2) == 0);
KAI_ASSUME(bl == kai_bl);
return kai_roundup(k, bl) / bl;
}
inline static size_t kai_num_bytes_per_block(size_t bl) {
KAI_ASSUME(bl == kai_bl);
return (bl / 2) + kai_num_bytes_multiplier;
}
inline static size_t kai_rhs_stride(size_t k, size_t bl) {
KAI_ASSUME(bl == kai_bl);
KAI_ASSUME((k % 2) == 0);
KAI_ASSUME((k % bl) == 0);
const size_t num_blocks_per_row = kai_num_blocks_per_row(k, bl);
const size_t num_bytes_per_block = kai_num_bytes_per_block(bl);
return num_bytes_per_block * num_blocks_per_row;
}
static inline size_t num_blocks_per_row(size_t k, size_t bl) {
return k / bl;
}
static inline size_t num_bytes_per_block_qs4c32(size_t bl) {
return (bl / 2) + sizeof(int16_t);
}
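// Quantize an n x k fp32 RHS matrix into the qs4c32 layout expected by the KleidiAI nxk int4 packer:
// for every 32-element block along k, an f16 scale (max / -8, where max is the element of largest
// magnitude) is stored first, followed by 16 bytes of 4-bit values (offset by 8), with element i in the
// low nibble and element i+16 in the high nibble of each byte.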
static void quant_qs4c32_f32(size_t n, size_t k, size_t bl, const float* rhs_f32, uint8_t* rhs_qs4c32) {
const size_t num_blocks_row = num_blocks_per_row(k, bl);
const size_t num_bytes_block = num_bytes_per_block_qs4c32(bl);
const size_t dst_stride = num_blocks_row * num_bytes_block;
#pragma omp parallel for
for (size_t row_idx = 0; row_idx < n; ++row_idx) {
const float* src_ptr = rhs_f32 + row_idx * k;
uint8_t* dst_ptr = (uint8_t*)rhs_qs4c32 + row_idx * dst_stride;
for (size_t block_idx = 0; block_idx < num_blocks_row; ++block_idx) {
float amax = 0.0f;
float max = 0.0f;
for (size_t b = 0; b < bl; ++b) {
const float src0_0 = src_ptr[block_idx * bl + b];
const float asrc0_0 = fabsf(src0_0);
if (amax < asrc0_0) {
amax = asrc0_0;
max = src0_0;
}
}
const float scale = max / -8.0;
const float recip_scale = scale ? 1.0f / scale : 0.0f;
// Store the scale at the beginning of the block
*((uint16_t*)dst_ptr) = kai_cast_f16_f32(scale);
dst_ptr += sizeof(uint16_t);
const size_t block_size = 32;
const size_t num_subblocks = bl / 32;
for (size_t subblock_idx = 0; subblock_idx < num_subblocks; ++subblock_idx) {
for (size_t i = 0; i < block_size / 2; ++i) {
const size_t src_base_addr = block_idx * bl + i + subblock_idx * block_size;
float v0_f32 = src_ptr[src_base_addr];
float v1_f32 = src_ptr[src_base_addr + block_size / 2];
v0_f32 *= recip_scale;
v1_f32 *= recip_scale;
const uint8_t v0_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v0_f32 + 8.5f));
const uint8_t v1_u8 = (uint8_t)std::min((int8_t)15, (int8_t)(v1_f32 + 8.5f));
const uint8_t rhs_v0 = (v1_u8 << 4) | v0_u8;
dst_ptr[0] = rhs_v0;
dst_ptr += sizeof(uint8_t);
}
}
}
}
}
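// Entry point for weight repacking. The packing backend is selected once: the custom bf16 packer by
// default, or the KleidiAI bf16 packer when ARM_GEMM_USE_KAI is set. lm_head is transposed before
// packing; the attention and FFN projection weights are repacked in their original [k, n] orientation.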
ConstBufferPtr prepareGemmWeight(const std::string& key, ConstBufferPtr input) {
if (armPrepareWeightFunc == nullptr) {
if (std::getenv("ARM_GEMM_USE_KAI") == nullptr) {
armPrepareWeightFunc = prepareGemmOptWeight;
} else {
RTP_LLM_LOG_INFO("KleidiAI enabled.\n");
armPrepareWeightFunc = prepareKaiWeightBf16;
}
}
// Transpose and reorder
if (key == W::lm_head) {
return armPrepareWeightFunc(transposeWeight(input), true, true);
}
// Reorder RHS weight matrices for better GEMM performance
if (key == W::attn_qkv_w) {
return armPrepareWeightFunc(input, false, true);
}
if (key == W::attn_o_w ||
key == W::ffn_w1 ||
key == W::ffn_w2 ||
key == W::ffn_w3) {
return armPrepareWeightFunc(input, false, false);
}
return input;
}
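// Transpose the trailing [k, n] matrix of the weight with ACL's NETranspose. The transposed data is
// copied back into the input buffer to avoid holding two copies, and a new Buffer with the trailing
// dimensions swapped is returned over the same storage.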
BufferPtr transposeWeight(ConstBufferPtr input) {
std::vector<size_t> Bshape = input->shape();
auto dim = input->dim();
size_t k;
size_t n;
k = Bshape[dim - 2];
n = Bshape[dim - 1];
arm_compute::NETranspose transB;
arm_compute::Tensor wei_tran_tensor;
arm_compute::TensorInfo wei_data_info;
arm_compute::TensorInfo wei_tran_info;
arm_compute::Tensor wei_tensor;
BufferPtr output;
auto data_type = input->type();
// Default to F32 so acl_data_type is never used uninitialized if an unexpected type slips through.
arm_compute::DataType acl_data_type = arm_compute::DataType::F32;
if (data_type == DataType::TYPE_FP16)
acl_data_type = arm_compute::DataType::F16;
else if (data_type == DataType::TYPE_FP32)
acl_data_type = arm_compute::DataType::F32;
else
RTP_LLM_LOG_WARNING("Unsupported data type %d\n", data_type);
wei_data_info = arm_compute::TensorInfo(arm_compute::TensorShape(n, k), 1, acl_data_type);
wei_tran_info = arm_compute::TensorInfo(arm_compute::TensorShape(k, n), 1, acl_data_type);
std::vector<size_t> weight_workspace_shape = std::vector<size_t>(Bshape.begin(), Bshape.end() - 2);
weight_workspace_shape.insert(weight_workspace_shape.end(), {n, k});
size_t element_num = k * n;
size_t data_size = data_type == DataType::TYPE_FP32 ? sizeof(float) : sizeof(float16_t);
//const void *data = malloc(element_num * data_size);
//output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
// data_type,
// weight_workspace_shape,
// data)),
size_t transposed_size = element_num * data_size;
void *transposed_data = malloc(transposed_size);
wei_tensor.allocator()->init(wei_data_info);
wei_tran_tensor.allocator()->init(wei_tran_info);
wei_tensor.allocator()->import_memory(input->data());
//wei_tran_tensor.allocator()->import_memory(output->data());
wei_tran_tensor.allocator()->import_memory(transposed_data);
transB.configure(&wei_tensor, &wei_tran_tensor);
transB.run();
// Update the input buffer with the transposed data in place to reduce memory usage
RTP_LLM_CHECK_WITH_INFO(input->sizeBytes() >= transposed_size, "transpose dst size < src size");
memcpy(input->data(), transposed_data, transposed_size);
free(transposed_data);
auto packedBuffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
data_type,
weight_workspace_shape,
input->data()));
return packedBuffer;
}
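// Pack an fp32/fp16 RHS weight into the KleidiAI bf16 12x4 block layout consumed by the
// matmul_clamp_f{16,32}_bf16p8x4_bf16p12x4b kernels. A zero bias row is packed alongside the weights;
// for fp16 inputs, isForceF32Out selects the packing variant whose bias is stored as fp32.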
ConstBufferPtr prepareKaiWeightBf16(ConstBufferPtr input, bool isTranspose, bool isForceF32Out) {
ConstBufferPtr output = input;
std::vector<size_t> Bshape = input->shape();
auto dim = input->dim();
size_t k;
size_t n;
k = Bshape[dim - 2];
n = Bshape[dim - 1];
if (input->type() == DataType::TYPE_FP32) {
const size_t nr = kai_get_nr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t kr = kai_get_kr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t sr = kai_get_sr_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
// In a single row, we pack nr bias values followed by K rows of nr RHS values
const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(n, k, nr, kr);
uint8_t* rhs_packed = new uint8_t[rhs_packed_size];
std::vector<size_t> weight_workspace_shape = std::vector<size_t>(Bshape.begin(), Bshape.end() - 2);
if (isTranspose)
weight_workspace_shape.insert(weight_workspace_shape.end(), {n, k});
else
weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n});
output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
DataType::TYPE_BF16,
weight_workspace_shape,
rhs_packed));
float* bias = new float[n];
memset(bias, 0, sizeof(float) * n);
const size_t rhs_stride = n * sizeof(float);
float* rhs = (float* )input->data();
// Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant.
int n_step = nr;
#pragma omp parallel for
for (int n_start = 0; n_start < n; n_start += n_step) {
const size_t rhs_offset = kai_get_rhs_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(n_start);
const size_t bias_offset = kai_get_bias_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(n_start);
const size_t packed_offset = kai_get_rhs_packed_offset_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(n_start, k, nr, kr);
int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;
kai_run_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon(
1, tile_n, k, nr, kr, sr, // Packing arguments
rhs_stride, // RHS stride
((uint8_t*)rhs + rhs_offset), // RHS
((uint8_t*)bias + bias_offset), // Bias
NULL, // Scale
(rhs_packed + packed_offset), // RHS packed
0, NULL);
}
delete[] bias;
return output;
} else if (input->type() == DataType::TYPE_FP16) {
const size_t nr = kai_get_nr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t kr = kai_get_kr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
const size_t sr = kai_get_sr_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla();
// In a single row, we pack nr bias values followed by K rows of nr RHS values
const size_t rhs_packed_size = isForceF32Out? kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n, k) :
kai_get_rhs_packed_size_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(n, k);
uint8_t* rhs_packed = new uint8_t[rhs_packed_size];
std::vector<size_t> weight_workspace_shape = std::vector<size_t>(Bshape.begin(), Bshape.end() - 2);
if (isTranspose)
weight_workspace_shape.insert(weight_workspace_shape.end(), {n, k});
else
weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n});
output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
DataType::TYPE_BF16,
weight_workspace_shape,
rhs_packed));
const size_t rhs_stride = n * sizeof(float16_t);
float16_t* rhs = (float16_t* )input->data();
// Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant.
// Unlike the fp32 path above, the whole N range is packed in a single call (n_step == n).
int n_step = n;
if (isForceF32Out) {
float* bias = new float[n];
memset(bias, 0, sizeof(float) * n);
#pragma omp parallel for
for (int n_start = 0; n_start < n; n_start += n_step) {
const size_t rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n_start);
const size_t bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n_start);
const size_t packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(n_start, k);
int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;
kai_run_rhs_pack_kxn_bf16p12x4biasf32_f16_neon(
1, tile_n, k, nr, kr, sr, // Packing arguments
rhs_stride, // RHS stride
((uint8_t*)rhs + rhs_offset), // RHS
((uint8_t*)bias + bias_offset), // Bias
NULL, // Scale
(rhs_packed + packed_offset), // RHS packed
0, NULL);
}
delete[] bias;
} else {
float16_t* bias = new float16_t[n];
memset(bias, 0, sizeof(float16_t) * n);
#pragma omp parallel for
for (int n_start = 0; n_start < n; n_start += n_step) {
const size_t rhs_offset = kai_get_rhs_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(n_start);
const size_t bias_offset = kai_get_bias_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(n_start);
const size_t packed_offset = kai_get_rhs_packed_offset_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(n_start, k);
int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;
kai_run_rhs_pack_kxn_bf16p12x4biasf16_f16_neon(
1, tile_n, k, nr, kr, sr, // Packing arguments
rhs_stride, // RHS stride
((uint8_t*)rhs + rhs_offset), // RHS
((uint8_t*)bias + bias_offset), // Bias
NULL, // Scale
(rhs_packed + packed_offset), // RHS packed
0, NULL);
}
delete[] bias;
}
return output;
}
return output;
}
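// Pack an fp32/fp16 weight into the custom bf16 layout consumed by GemmKernel: K is padded up to a
// multiple of 8 (weight_k_pack), and columns are processed in pairs so each packed row holds
// width = 2 * weight_k_pack bf16 values for height = ceil(n / 2) column pairs per batch.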
ConstBufferPtr prepareGemmOptWeight(ConstBufferPtr input, bool isTranspose, bool unused) {
ConstBufferPtr weight_workspace = input;
GemmKernel gemm_kernel;
std::vector<size_t> Bshape = input->shape();
auto dim = input->dim();
size_t k;
size_t n;
k = Bshape[dim - 2];
n = Bshape[dim - 1];
size_t batch_size = std::accumulate(Bshape.begin(), Bshape.end() - 2, (size_t)1, std::multiplies<size_t>());
size_t weight_k_pack = std::ceil(k / 8.0) * 8;
size_t width = weight_k_pack * 2;
size_t height = n / 2 + n % 2;
if (input->type() == DataType::TYPE_FP32 || input->type() == DataType::TYPE_FP16) {
// allocate a temp workspace to pack the weight (fp32/fp16 -> bf16)
std::vector<size_t> weight_workspace_shape = std::vector<size_t>(Bshape.begin(), Bshape.end() - 2);
if (isTranspose)
weight_workspace_shape.insert(weight_workspace_shape.end(), {n, k});
else
weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n});
// weight_workspace = device->allocateBuffer({DataType::TYPE_BF16, weight_workspace_shape, AllocationType::DEVICE}, {"gemm_weight_workspace"});
// The packed layout holds height (= ceil(n / 2)) rows of width (= 2 * weight_k_pack) bf16 values per
// batch, which can exceed k * n when k is not a multiple of 8 or n is odd, so size the allocation
// from the padded dimensions rather than the original shape.
size_t packed_element_num = batch_size * height * width;
const void *data = malloc(packed_element_num * sizeof(hie::bfloat16));
weight_workspace = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
DataType::TYPE_BF16,
weight_workspace_shape,
data));
memset(weight_workspace->data(), 0, packed_element_num * sizeof(hie::bfloat16));
// pack weight
for (size_t batch = 0; batch < batch_size; ++batch) {
hie::bfloat16* weight_workspace_cur_ptr = reinterpret_cast<hie::bfloat16*>(weight_workspace->dataWithOffset(batch * height * width));
if (input->type() == DataType::TYPE_FP32) {
float* B_fp32_ptr = reinterpret_cast<float*>(input->dataWithOffset(batch * k * n));
gemm_kernel.gemm_pack_weight_FP32toBF16_arm(n, k, weight_k_pack, B_fp32_ptr, weight_workspace_cur_ptr);
} else { // if(params.B.type() == DataType::TYPE_FP16)
float16_t* B_fp16_ptr = reinterpret_cast<float16_t*>(input->dataWithOffset(batch * k * n));
gemm_kernel.gemm_pack_weight_FP16toBF16_arm(n, k, weight_k_pack, B_fp16_ptr, weight_workspace_cur_ptr);
}
}
// The packed workspace is returned directly; the in-place copy-back into the original buffer below is disabled.
//RTP_LLM_CHECK_WITH_INFO(input->sizeBytes() >= weight_workspace->sizeBytes(), "gemm pack dst size < src size");
//memcpy(input->data(), weight_workspace->data(), weight_workspace->sizeBytes());
//free(weight_workspace->data());
auto packedBuffer = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
DataType::TYPE_BF16,
weight_workspace_shape,
//input->data()));
weight_workspace->data()));
return packedBuffer;
}
return weight_workspace;
}
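// torch-facing entry point: wraps the tensor as a Buffer, runs prepareGemmWeight, and converts the
// repacked result back to a torch::Tensor.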
torch::Tensor ArmCpuDevice::preprocessGemmWeightByKey(const std::string& key, torch::Tensor weight) {
auto buffer = torchTensor2Buffer(weight);
auto retBuffer = prepareGemmWeight(key, buffer);
// Repacked buffer size may not match with shape size * element size,
// should use buffer pointer instead of copying data.
if ((key == W::attn_qkv_w ||
key == W::attn_o_w ||
key == W::ffn_w1 ||
key == W::ffn_w2 ||
key == W::ffn_w3 ||
key == W::lm_head) && retBuffer->type() == DataType::TYPE_BF16) {
return Buffer2torchTensor(*retBuffer, false);
}
if ((key == W::attn_qkv_w ||
key == W::attn_o_w ||
key == W::ffn_w1 ||
key == W::ffn_w2 ||
key == W::ffn_w3) && retBuffer->type() == DataType::TYPE_UINT8) {
return Buffer2torchTensor(*retBuffer, false);
}
return Buffer2torchTensor(*retBuffer);
}
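// Repack GPTQ int4 weights (two 4-bit values per stored byte, fp16 scales with a group size of 128
// along k). The nibbles are decoded as signed 4-bit values, dequantized to float, transposed to
// [n, k], then requantized into 32-element qsi4c32 blocks and packed for the KleidiAI int4 kernels
// (or, with GPTQ_COMPUTE_AS_DI_BF16, dequantized and repacked as bf16 for the custom GEMM path).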
ConstBufferPtr prepareGemmOptForGPTQInt4(ConstBufferPtr kernel, ConstBufferPtr scales, const std::string& key) {
ConstBufferPtr weight_workspace = kernel;
std::vector<size_t> Bshape = kernel->shape();
auto dim = kernel->dim();
size_t k;
size_t n;
k = Bshape[dim - 2];
n = Bshape[dim - 1];
// Each stored byte holds two 4-bit values, so the logical N is twice the last dimension of the kernel tensor.
n *= 2;
#if GPTQ_COMPUTE_AS_DI_BF16
GemmKernel gemm_kernel;
size_t weight_k_pack = std::ceil(k / 8.0) * 8;
if (kernel->type() == DataType::TYPE_INT8 && scales->type() == DataType::TYPE_FP16) {
int8_t* qweight = (int8_t*)kernel->data();
auto qscales = (__fp16*)scales->data();
__fp16* unpacked_weight = (__fp16*)malloc(k * n * 2);
for (int i = 0; i < k; i++) {
for (int j = 0; j < n; j += 2) {
int8_t qint8 = qweight[i * (n / 2) + j / 2];
__fp16 scale_0 = qscales[i / 128 * n + j ];
__fp16 scale_1 = qscales[i / 128 * n + j + 1];
auto elt_0 = qint8 & 0x0F;
auto elt_1 = (qint8 >> 4) & 0x0F;
if (elt_0 & 0x08) {
elt_0 -= 16;
}
if (elt_1 & 0x08) {
elt_1 -= 16;
}
auto x0 = scale_0 * elt_0;
auto x1 = scale_1 * elt_1;
unpacked_weight[i * n + j ] = x0;
unpacked_weight[i * n + j + 1] = x1;
}
}
std::vector<size_t> weight_workspace_shape = std::vector<size_t>(Bshape.begin(), Bshape.end() - 2);
weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n});
size_t element_num = std::accumulate(Bshape.begin(), Bshape.end(), (size_t)1, std::multiplies<size_t>());
element_num *= 2;
const void *data = malloc(element_num * sizeof(hie::bfloat16));
weight_workspace = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
DataType::TYPE_BF16,
weight_workspace_shape,
data));
memset(weight_workspace->data(), 0, weight_workspace->sizeBytes());
hie::bfloat16* weight_workspace_cur_ptr = reinterpret_cast<hie::bfloat16*>(weight_workspace->data());
gemm_kernel.gemm_pack_weight_FP16toBF16_arm(n, k, weight_k_pack, unpacked_weight, weight_workspace_cur_ptr);
free(unpacked_weight);
return weight_workspace;
#else
if (kernel->type() == DataType::TYPE_INT8 && scales->type() == DataType::TYPE_FP16) {
int8_t* qweight = (int8_t*)kernel->data();
auto qscales = (__fp16*)scales->data();
float* unpacked_weight = (float*)malloc(k * n * sizeof(float));
#pragma omp parallel for collapse(2)
for (int i = 0; i < k; i++) {
for (int j = 0; j < n; j += 2) {
int8_t qint8 = qweight[i * (n / 2) + j / 2];
__fp16 scale_0 = qscales[i / 128 * n + j ];
__fp16 scale_1 = qscales[i / 128 * n + j + 1];
auto elt_0 = qint8 & 0x0F;
auto elt_1 = (qint8 >> 4) & 0x0F;
if (elt_0 & 0x08) {
elt_0 -= 16;
}
if (elt_1 & 0x08) {
elt_1 -= 16;
}
auto x0 = scale_0 * elt_0;
auto x1 = scale_1 * elt_1;
unpacked_weight[i * n + j ] = x0;
unpacked_weight[i * n + j + 1] = x1;
}
}
std::vector<size_t> input_shape = std::vector<size_t>(Bshape.begin(), Bshape.end() - 2);
input_shape.insert(input_shape.end(), {k, n});
BufferPtr input = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
DataType::TYPE_FP32,
input_shape,
unpacked_weight));
auto transposedWeight = transposeWeight(input);
const size_t bl = 32;
const size_t num_blocks = k / bl;
const size_t num_bytes_per_block_qs4c32 = (bl / 2) + sizeof(int16_t);
const size_t rhs_native_size_qs4c32 = n * num_blocks * num_bytes_per_block_qs4c32;
const size_t nr = kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
const size_t kr = kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
const size_t sr = kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod();
// In a single row, we pack nr bias values followed by K rows of nr RHS values
const size_t rhs_packed_size = kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n, k, nr, kr, bl);
uint8_t* rhs_packed_mtx_qs4c32 = new uint8_t[rhs_packed_size];
std::vector<size_t> weight_workspace_shape = std::vector<size_t>(Bshape.begin(), Bshape.end() - 2);
weight_workspace_shape.insert(weight_workspace_shape.end(), {k, n / 2});
BufferPtr output = BufferPtr(new Buffer(MemoryType::MEMORY_CPU,
DataType::TYPE_UINT8,
weight_workspace_shape,
rhs_packed_mtx_qs4c32));
uint8_t* rhs_native_mtx_qs4c32 = new uint8_t[rhs_native_size_qs4c32];
quant_qs4c32_f32(
n, k, bl, (const float*)transposedWeight->data(), (uint8_t*)rhs_native_mtx_qs4c32);
struct kai_rhs_pack_qs4cxs1s0_param kai_rhs_params;
kai_rhs_params.lhs_zero_point = 1;
kai_rhs_params.rhs_zero_point = 8;
// Packing only needs to be performed once if the contents of the bias and RHS matrices are expected to be constant.
int n_step = 32;
size_t rhs_stride = kai_rhs_stride(k, bl);
#pragma omp parallel for
for (int n_start = 0; n_start < n; n_start += n_step) {
const size_t rhs_offset = kai_get_rhs_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n_start, rhs_stride);
const size_t packed_offset = kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(n_start, k, nr, kr, bl);
int tile_n = (n_start + n_step <= n) ? n_step : n - n_start;
kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0(
1, tile_n, k, // Dimensions
nr, kr, sr, // Packing arguments
bl, // Block length
(const uint8_t*)(rhs_native_mtx_qs4c32 + rhs_offset), // RHS
NULL, // Bias
((uint8_t*)rhs_packed_mtx_qs4c32 + packed_offset), // RHS packed
0, &kai_rhs_params
);
}
delete[] rhs_native_mtx_qs4c32;
free(unpacked_weight);
return output;
#endif
}
return weight_workspace;
}
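// Pack the fp32 LHS (activation) matrix into the interleaved two-row bf16 layout expected by the
// bf16 GEMM microkernels.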
void GemmKernel::pack_input_arm(int M, int N, int K, int lda, int K_pack, float* a_fp32, hie::bfloat16* a_bf16) {
pack_input_impl_parallel_simd(M, N, K, lda, K_pack, a_fp32, a_bf16);
return;
}
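// Convert the fp32 RHS from row-major [K, N] to the packed bf16 layout: columns are taken in pairs
// (n, n+1) and, for every 4 consecutive k values, 4 bf16 values of column n are stored followed by
// 4 bf16 values of column n+1. Work is split over K tiles of k_tile rows.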
void GemmKernel::gemm_pack_weight_FP32toBF16_arm(int N, int K, int K_pack, const float* b_fp32, hie::bfloat16* b_bf16) {
int k_tile = 1024; // empirical var: 1024, 5120
int k_thread = std::ceil(K_pack * 1.0 / k_tile);
parallel_for(k_thread, [&](int k) {
for (int n = 0; n < N; n += 2) {
float* b_fp32_ptr1 = (float*)b_fp32 + k * k_tile * N + n + 0;
float* b_fp32_ptr2 = (float*)b_fp32 + k * k_tile * N + n + 1;
hie::bfloat16* b_bf16_ptr = b_bf16 + n * K_pack + k * k_tile * 2; // [n, k*k_tile*2]
int kk_max = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
for (int kk = k * k_tile; kk < kk_max; kk += 4) {
for (int i = 0; i < 4 && (kk + i < kk_max); i++) {
b_bf16_ptr[i] = b_fp32_ptr1[i * N];
if (n != (N - 1)) {
b_bf16_ptr[i + 4] = b_fp32_ptr2[i * N];
}
}
b_bf16_ptr += 8;
b_fp32_ptr1 += 4 * N;
b_fp32_ptr2 += 4 * N;
}
}
});
#ifdef PACK_DEBUG
for (int i = 0; i < N; i++) {
for (int j = 0; j < K; j++) {
if (j % 8 == 0) {
printf("\n");
}
printf("%f ", b_fp32[j * N + i]);
}
printf("\n");
printf("\n");
}
printf("\n");
auto N_aligned = N / 2 + (N % 2);
for (int i = 0; i < N_aligned; i++) {
for (int j = 0; j < K_pack * 2; j++) {
if (j % 8 == 0) {
printf("\n");
}
std::cout << std::setiosflags(std::ios::fixed) << std::setprecision(6) << b_bf16[i * K_pack * 2 + j] << " ";
}
printf("\n");
printf("\n");
}
printf("\n");
#endif
return;
}
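// Same packing layout as gemm_pack_weight_FP32toBF16_arm, but converting from fp16 source data.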
void GemmKernel::gemm_pack_weight_FP16toBF16_arm(int N, int K, int K_pack, const float16_t* b_fp16, hie::bfloat16* b_bf16) {
int k_tile = 1024; // empirical var: 1024, 5120
int k_thread = std::ceil(K_pack * 1.0 / k_tile);
parallel_for(k_thread, [&](int k) {
for (int n = 0; n < N; n += 2) {
float16_t* b_fp16_ptr1 = (float16_t*)b_fp16 + k * k_tile * N + n + 0;
float16_t* b_fp16_ptr2 = (float16_t*)b_fp16 + k * k_tile * N + n + 1;
hie::bfloat16* b_bf16_ptr = b_bf16 + n * K_pack + k * k_tile * 2; // [n, k*k_tile*2]
int kk_max = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
for (int kk = k * k_tile; kk < kk_max; kk += 4) {
for (int i = 0; i < 4 && (kk + i < kk_max); i++) {
b_bf16_ptr[i] = b_fp16_ptr1[i * N];
if (n != (N - 1)) {
b_bf16_ptr[i + 4] = b_fp16_ptr2[i * N];
}
}
b_bf16_ptr += 8;
b_fp16_ptr1 += 4 * N;
b_fp16_ptr2 += 4 * N;
}
}
});
#ifdef PACK_DEBUG
for (int i = 0; i < N; i++) {
for (int j = 0; j < K; j++) {
if (j % 8 == 0) {
printf("\n");
}
std::cout << b_fp16[j * N + i] << " ";
}
printf("\n");
printf("\n");
}
printf("\n");
auto N_aligned = N / 2 + (N % 2);
for (int i = 0; i < N_aligned; i++) {
for (int j = 0; j < K_pack * 2; j++) {
if (j % 8 == 0) {
printf("\n");
}
std::cout << std::setiosflags(std::ios::fixed) << std::setprecision(6) << b_bf16[i * K_pack * 2 + j] << " ";
}
printf("\n");
printf("\n");
}
printf("\n");
#endif
return;
}
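// SVE implementation of LHS packing from fp16: rows are processed in pairs, 8 fp16 values per row are
// widened to fp32 and converted to bf16 with BFCVT, and the results are stored interleaved in groups
// of 4 (4 values of row m followed by 4 values of row m+1), matching the packed weight layout.
// Work is tiled over K (k_tile columns per task) for parallelism.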
void GemmKernel::pack_input_fp16tobf16_impl_parallel_simd(
int M, int N, int K, int lda, int K_pack, float16_t* a_fp16, hie::bfloat16* a_bf16) {
#define LABEL_FOR_LOOP_M "0"
#define LABEL_FOR_LOOP_K "1"
#define LABEL_m_EQ_M_1 "2"
int k_tile = 1024; // empirical var: 1024, 5120
int k_thread = std::ceil(K * 1.0 / k_tile);
// printf("k_tile: %d, k_thread: %d\n", k_tile, k_thread);
// fp16 [ a[i, j+0], a[i, j+1], a[i, j+2], a[i, j+3] ]
// fp16 [ a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3] ]
// bf16 [ a[i, j+0], a[i, j+1], a[i, j+2], a[i, j+3],
// a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3]]
parallel_for(k_thread, [&](int k) {
float16_t* a_fp16_ptr1 = a_fp16 + 0 * lda + k * k_tile;
float16_t* a_fp16_ptr2 = a_fp16 + 1 * lda + k * k_tile;
hie::bfloat16* a_bf16_ptr = a_bf16 + k * k_tile * 2;
int a_fp16_offset = 2 * lda * sizeof(float16_t);
int a_bf16_offset = 2 * K_pack * sizeof(hie::bfloat16); // if K_pack % 16 == 8, the next row's stride covers the remaining 8 zero elements
int kk = k * k_tile;
int kk_max = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
// clang-format off
asm volatile(
"ptrue p0.b \n"
"sub x1, %[M], #1 \n" // M - 1
"mov x2, #0 \n" // m
"" LABEL_FOR_LOOP_M
":\n"
"mov x3, %[a_fp16_ptr1] \n"
"mov x4, %[a_fp16_ptr2] \n"
"mov x5, %[a_bf16_ptr] \n"
"prfw pldl1strm, p0, [x3, #0, MUL VL] \n" // prefetch
"prfw pldl1strm, p0, [x4, #0, MUL VL] \n"
"mov x0, %[kk] \n"
"whilelt p1.h, x0, %[kk_max] \n" // compare kk
// and kk_max
"" LABEL_FOR_LOOP_K
":\n"
"ld1h z4.h, p1/z, [x3, #0, MUL VL] \n" // load 8 fp16
"dup z6.h, #0 \n"
"zip1 z0.h, z4.h, z6.h \n" // zip 4(or less) fp16 values with 0
"zip2 z1.h, z4.h, z6.h \n" // zip 4(or less) fp16 values with 0
"fcvt z0.s, p0/m, z0.h \n" // fp16 -> fp32
"dup z2.h, #0 \n"
"fcvt z1.s, p0/m, z1.h \n" // fp16 -> fp32
"dup z3.h, #0 \n"
"cmp x2, x1 \n" // compare m,
// M - 1
"b.none " LABEL_m_EQ_M_1
"f \n"
"ld1h z5.h, p1/z, [x4, #0, MUL VL] \n" // load, when
// m != M - 1
"zip1 z2.h, z5.h, z6.h \n" // zip 4(or less) fp16 values with 0
"zip2 z3.h, z5.h, z6.h \n" // zip 4(or less) fp16 values with 0
"fcvt z2.s, p0/m, z2.h \n" // fp16 -> fp32
"fcvt z3.s, p0/m, z3.h \n" // fp16 -> fp32
"" LABEL_m_EQ_M_1
":\n"
"add x3, x3, #16 \n" // a_fp16_ptr1 += 8
"add x4, x4, #16 \n" // a_fp16_ptr2 += 8
// "add x3, x3, #8 \n" // a_fp16_ptr1 += 4
// "add x4, x4, #8 \n" // a_fp16_ptr2 += 4
"prfw pldl1strm, p0, [x3, #0, MUL VL] \n"
"prfw pldl1strm, p0, [x4, #0, MUL VL] \n"
"bfcvt z0.h, p0/m, z0.s \n" // fp32 ->
// bf16
"bfcvt z1.h, p0/m, z1.s \n"
"bfcvt z2.h, p0/m, z2.s \n"
"bfcvt z3.h, p0/m, z3.s \n"
"uzp1 z4.h, z0.h, z2.h \n" // combine
// bf16
"uzp1 z5.h, z1.h, z3.h \n" // combine bf16
"zip1 p3.d, p1.d, p1.d \n" // cp 4 least significant half to 4 most significant half
""
"st1h z4.h, p3, [x5, #0, MUL VL] \n" // store bf16 data
"zip2 p3.d, p1.d, p1.d \n" // cp 4 most significant half to 4 least significant half
"st1h z5.h, p3, [x5, #1, MUL VL] \n" // store bf16
"add x5, x5, #32 \n" // a_bf16_ptr += 16
// "add x5, x5, #16 \n" // a_bf16_ptr += 8
// "prfw pstl1keep, p0, [x5, #0, MUL VL] \n"
"add x0, x0, #8 \n" // kk += 8
// "add x0, x0, #4 \n" // kk += 4
"whilelt p1.h, x0, %[kk_max] \n" // compare kk
// and kk_max
"b.tstop " LABEL_FOR_LOOP_K
"b \n" // if k < K_MAX, go to label
"add %[a_fp16_ptr1], %[a_fp16_ptr1], %[a_fp16_offset] \n"
"add %[a_fp16_ptr2], %[a_fp16_ptr2], %[a_fp16_offset] \n"
"add %[a_bf16_ptr], %[a_bf16_ptr], %[a_bf16_offset] \n"
"add x2, x2, #2 \n" // m += 2
"cmp x2, %[M] \n" // compare m,
// M
"b.tstop " LABEL_FOR_LOOP_M
"b \n" // if m < M, go to label
: /* empty OutputOperands */
: [a_fp16_ptr1] "r"(a_fp16_ptr1), [a_fp16_ptr2] "r"(a_fp16_ptr2),
[a_bf16_ptr] "r"(a_bf16_ptr), [kk] "r"(kk), [kk_max] "r"(kk_max),
[M] "r"(M), [a_fp16_offset] "r"(a_fp16_offset),
[a_bf16_offset] "r"(a_bf16_offset)
: "x0", "x1", "x2", "x3", "x4", "x5",
"p0", "p1", "p2", "p3",
"z0", "z1", "z2", "z3", "z4", "z5", "z6",
"cc", "memory");
// clang-format on
});
#ifdef PACK_DEBUG
for (int i = 0; i < M; i++) {
for (int j = 0; j < K; j++) {
if (j % 8 == 0) {
printf("\n");
}
printf("%f ", a_fp16[i * lda + j]);
// std::cout << a_fp16[i * lda + j] << " ";
}
printf("\n");
printf("\n");
}
printf("\n");
// int k_pack_compute = std::ceil(K / 16.0) * 16;
auto M_aligned = M + (M % 2);
for (int i = 0; i < M_aligned / 2; i++) {
for (int j = 0; j < K_pack * 2; j++) {
if (j % 8 == 0) {
printf("\n");
}
std::cout << a_bf16[i * K_pack * 2 + j] << " ";
}
printf("\n");
printf("\n");
}
printf("\n");
#endif
return;
}
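// fp32 variant of the SVE LHS packing above: 4 fp32 values per row are converted to bf16 per
// iteration and stored as 4 values of row m followed by 4 values of row m+1, again tiled over K
// for parallelism.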
void GemmKernel::pack_input_impl_parallel_simd(
int M, int N, int K, int lda, int K_pack, float* a_fp32, hie::bfloat16* a_bf16) {
#define LABEL_FOR_LOOP_M "0"
#define LABEL_FOR_LOOP_K "1"
#define LABEL_m_EQ_M_1 "2"
int k_tile = 1024; // empirical var: 1024, 5120
int k_thread = std::ceil(K * 1.0 / k_tile);
// printf("k_tile: %d, k_thread: %d\n", k_tile, k_thread);
// fp32 [ a[i, j+0], a[i, j+1], a[i, j+2], a[i, j+3] ]
// fp32 [ a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3] ]
// bf16 [ a[i, j+0], a[i, j+1], a[i, j+2], a[i, j+3],
// a[i+1,j+0], a[i+1,j+1], a[i+1,j+2], a[i+1,j+3] ]
parallel_for(k_thread, [&](int k) {
float* a_fp32_ptr1 = a_fp32 + 0 * lda + k * k_tile;
float* a_fp32_ptr2 = a_fp32 + 1 * lda + k * k_tile;
hie::bfloat16* a_bf16_ptr = a_bf16 + k * k_tile * 2;
int a_fp32_offset = 2 * lda * sizeof(float);
int a_bf16_offset = 2 * K_pack * sizeof(hie::bfloat16);
int kk = k * k_tile;
int kk_max = (k + 1) * k_tile < K ? (k + 1) * k_tile : K;
// clang-format off
asm volatile(
"ptrue p0.b \n"
"sub x1, %[M], #1 \n" // M - 1
"mov x2, #0 \n" // m
"" LABEL_FOR_LOOP_M
":\n"
"mov x3, %[a_fp32_ptr1] \n"
"mov x4, %[a_fp32_ptr2] \n"
"mov x5, %[a_bf16_ptr] \n"
"prfw pldl1strm, p0, [x3, #0, MUL VL] \n" // prefetch
"prfw pldl1strm, p0, [x4, #0, MUL VL] \n"
"mov x0, %[kk] \n"
"whilelt p1.s, x0, %[kk_max] \n" // compare kk
// and kk_max
"" LABEL_FOR_LOOP_K
":\n"
"ld1w z0.s, p1/z, [x3, #0, MUL VL] \n"
"dup z1.h, #0 \n"
"cmp x2, x1 \n" // compare m,
// M - 1
"b.none " LABEL_m_EQ_M_1
"f \n"
"ld1w z1.s, p1/z, [x4, #0, MUL VL] \n" // load, when
// m != M - 1
"" LABEL_m_EQ_M_1
":\n"
"add x3, x3, #16 \n"
"add x4, x4, #16 \n"
"prfw pldl1strm, p0, [x3, #0, MUL VL] \n"
"prfw pldl1strm, p0, [x4, #0, MUL VL] \n"
"bfcvt z0.h, p0/m, z0.s \n" // fp32 ->
// bf16
"bfcvt z1.h, p0/m, z1.s \n"
"uzp1 z2.h, z0.h, z1.h \n" // combine
// bf16
"uzp1 p3.h, p1.h, p1.h \n"
"st1h z2.h, p3, [x5, #0, MUL VL] \n" // store bf16
// data
"add x5, x5, #16 \n"
// "prfw pstl1keep, p0, [x5, #0, MUL VL] \n"
"add x0, x0, #4 \n" // kk += 4
"whilelt p1.s, x0, %[kk_max] \n" // compare kk
// and kk_max
"b.tstop " LABEL_FOR_LOOP_K
"b \n" // if k < K_MAX, go to label
"add %[a_fp32_ptr1], %[a_fp32_ptr1], %[a_fp32_offset] \n"
"add %[a_fp32_ptr2], %[a_fp32_ptr2], %[a_fp32_offset] \n"
"add %[a_bf16_ptr], %[a_bf16_ptr], %[a_bf16_offset] \n"
"add x2, x2, #2 \n" // m += 2
"cmp x2, %[M] \n" // compare m,
// M
"b.tstop " LABEL_FOR_LOOP_M
"b \n" // if m < M, go to label
: /* empty OutputOperands */
: [a_fp32_ptr1] "r"(a_fp32_ptr1), [a_fp32_ptr2] "r"(a_fp32_ptr2),
[a_bf16_ptr] "r"(a_bf16_ptr), [kk] "r"(kk), [kk_max] "r"(kk_max),
[M] "r"(M), [a_fp32_offset] "r"(a_fp32_offset),
[a_bf16_offset] "r"(a_bf16_offset)
: "x0", "x1", "x2", "x3", "x4", "x5", "p0", "p1", "p2", "p3", "z0",
"z1", "z2", "cc", "memory");
// clang-format on
});
#ifdef PACK_DEBUG
for (int i = 0; i < M; i++) {
for (int j = 0; j < K; j++) {
if (j % 8 == 0) {
printf("\n");
}
printf("%f ", a_fp32[i * lda + j]);
}
printf("\n");
printf("\n");
}
printf("\n");
auto M_aligned = M + (M % 2);
for (int i = 0; i < M_aligned / 2; i++) {
for (int j = 0; j < K_pack * 2; j++) {
if (j % 8 == 0) {
printf("\n");
}
std::cout << a_bf16[i * K_pack * 2 + j] << " ";
}
printf("\n");
printf("\n");
}
printf("\n");
#endif
return;
}
} // namespace rtp_llm