//
// FunctionDispatcher.cpp
// MNN
//
// Created by MNN on 2019/08/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <algorithm>
#include <cstdint>
#include <limits>
#include <vector>
#include "avx512/FunctionSummary.hpp"
#include "avx/FunctionSummary.hpp"
#include "AVX2Functions.hpp"
#include "avxfma/FunctionSummary.hpp"
#include "backend/cpu/compute/CommonOptFunction.h"
#include "backend/cpu/compute/ConvOpt.h"
#include "backend/cpu/compute/Int8FunctionsOpt.h"
#include "cpu_id.h"
#include "sse/FunctionSummary.hpp"
// https://stackoverflow.com/a/11230437
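// FunctionGroup holds the elementwise kernel pointers, initialized to the SSE
// baseline. MNNFunctionInit() upgrades them at runtime from CPUID feature
// flags, so the exported wrappers at the bottom of this file always call the
// best variant the host supports.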
struct FunctionGroup {
    int tileNumber = 8;
    int eP = 12;
    int lP = 1;
    int hP = 4;
    void (*MNNExpC8)(float* dest, const float* source, float* offset, const float* parameters, size_t countC8) = _SSE_MNNExpC8;
    void (*MNNSoftmax)(float* dest, const float* source, size_t size) = _SSE_MNNSoftmax;
    void (*MNNReluInt8)(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) = _SSE_MNNReluInt8;
    void (*MNNHardSwish)(float* dst, const float* src, size_t size) = _SSE_MNNHardSwish;
    void (*MNNGelu)(float* dst, const float* src, size_t size, float* parameters) = _SSE_MNNGelu;
    void (*MNNNorm)(float* dst, const float* src, const float* gamma, const float* beta, float epsilon, size_t size, bool RMSNorm) = _SSE_MNNNorm;
};
static FunctionGroup gFunc;
void _SSEMNNGetMatMulPackMode(int* eP, int* lP, int* hP) {
    *eP = gFunc.eP;
    *lP = gFunc.lP;
    *hP = gFunc.hP;
}
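// Main dispatch entry: query CPU features once through libyuv's cpuid wrapper,
// then install progressively wider kernels (SSSE3 baseline, then AVX2/FMA;
// AVX512 variants, where built, are selected inside AVX2Functions::init).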
void MNNFunctionInit() {
    auto cpuFlags = libyuv::InitCpuFlags();
#ifdef __EMSCRIPTEN__
    // TODO: Find better way. cpuid is unavailable under WASM, so assume the
    // SSE baseline provided by SIMD128 emulation.
    cpuFlags |= libyuv::kCpuHasSSE41;
    cpuFlags |= libyuv::kCpuHasSSSE3;
#endif
    auto coreFunction = MNN::MNNGetCoreFunctions();
    if (cpuFlags & libyuv::kCpuHasSSSE3) {
        coreFunction->MNNGetMatMulPackMode = _SSEMNNGetMatMulPackMode;
        coreFunction->MNNPackedMatMul = _SSE_MNNPackedMatMul;
        coreFunction->MNNPackedMatMulRemain = _SSE_MNNPackedMatMulRemain;
#ifdef MNN_CPU_WEIGHT_DEQUANT_GEMM
        coreFunction->MNNPackedMatMul_int8 = _SSE_MNNPackedMatMul_int8;
        coreFunction->MNNPackedMatMulRemain_int8 = _SSE_MNNPackedMatMulRemain_int8;
#endif
#ifdef MNN_LOW_MEMORY
        coreFunction->MNNAbsMax = _SSE_MNNAbsMaxFP32;
        coreFunction->MNNDynamicQuant = _SSE_MNNDynamicQuant;
        coreFunction->MNNAsyQuantInfo = _SSE_MNNAsyQuantInfo;
        coreFunction->MNNAsyQuantFunc = _SSE_MNNAsyQuantFunc;
#endif
        coreFunction->MNNPackC4ForMatMul_A = _SSE_MNNPackC4ForMatMul_A;
        coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B;
        // Dynamic Quant
        coreFunction->MNNCountMaxMinValue = _SSE_MNNCountMinMaxValue;
    }
#ifdef MNN_USE_AVX
    if (cpuFlags & libyuv::kCpuHasAVX2) {
        // Install the AVX2-level core functions, then upgrade the elementwise
        // kernels in gFunc, preferring the FMA variants when FMA3 is present.
        MNN::AVX2Functions::init(cpuFlags);
        gFunc.MNNExpC8 = _AVX_MNNExpC8;
        gFunc.MNNSoftmax = _AVX_MNNSoftmax;
        gFunc.MNNGelu = _AVX_MNNGelu;
        if (cpuFlags & libyuv::kCpuHasFMA3) {
            gFunc.MNNGelu = _AVX_MNNGeluFMA;
            gFunc.MNNExpC8 = _AVX_MNNExpC8FMA;
        }
        gFunc.MNNNorm = _AVX_MNNNorm;
    }
#endif
    _SSE_ImageProcessInit(coreFunction, cpuFlags);
}
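// Scalar reference average pooling over 16-channel packed data; the buffers
// are reinterpreted as uint8, so callers are expected to pass data in the
// unsigned domain. `factor` is presumably a Q24 fixed-point reciprocal of the
// pooling window area, e.g. roughly (1 << 24) / 9 == 1864135 for a 3x3
// window, so `(sum * factor) >> 24` approximates sum / area without a
// per-pixel division. `paddingx` is kept for signature compatibility but is
// unused here.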
void MNNAvgPoolUint8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx, ssize_t paddingx, ssize_t factor) {
    const int pack = 16;
    uint32_t f = static_cast<uint32_t>(factor);
    uint8_t* dstPtr = reinterpret_cast<uint8_t*>(dst);
    const uint8_t* srcPtr = reinterpret_cast<const uint8_t*>(src);
    for (int ox = 0; ox < outputWidth; ++ox) {
        // Accumulate the window in 32 bits per channel to avoid overflow.
        std::vector<uint32_t> sum_(pack, 0);
        for (int y = 0; y < kernely; ++y) {
            for (int x = 0; x < kernelx; ++x) {
                const uint8_t* inputPtr = srcPtr + pack * (inputWidth * y + x);
                for (int idx = 0; idx < pack; ++idx) {
                    sum_[idx] += *(inputPtr + idx);
                }
            }
        }
        // Divide by the window area via the fixed-point factor.
        for (int idx = 0; idx < pack; ++idx) {
            *(dstPtr + idx) = static_cast<uint8_t>((sum_[idx] * f) >> 24);
        }
        dstPtr = dstPtr + pack;
        srcPtr = srcPtr + pack * stridesx;
    }
}
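// Scalar reference max pooling over the same 16-channel packed int8 layout;
// accumulators start at INT8_MIN so any value in the window can win the max.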
void MNNMaxPoolInt8_(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx) {
    const int pack = 16;
    int8_t* dstPtr = dst;
    const int8_t* srcPtr = src;
    for (int ox = 0; ox < outputWidth; ++ox) {
        std::vector<int8_t> results(pack, INT8_MIN);
        for (int y = 0; y < kernely; ++y) {
            for (int x = 0; x < kernelx; ++x) {
                const int8_t* inputPtr = srcPtr + pack * (x + inputWidth * y);
                for (int idx = 0; idx < pack; ++idx) {
                    results[idx] = std::max(results[idx], *(inputPtr + idx));
                }
            }
        }
        for (int idx = 0; idx < pack; ++idx) {
            *(dstPtr + idx) = results[idx];
        }
        dstPtr = dstPtr + pack;
        srcPtr = srcPtr + pack * stridesx;
    }
}
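// Registers the int8 kernels: the scalar pooling fallbacks above are always
// installed, while the quantization and gemm kernels require SSE4.1.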
void MNNInt8FunctionInit() {
    auto cpuFlags = libyuv::InitCpuFlags();
    auto core = MNN::MNNGetInt8CoreFunctions();
    core->MNNAvgPoolInt8 = MNNAvgPoolUint8;
    core->MNNMaxPoolInt8 = MNNMaxPoolInt8_;
    if (cpuFlags & libyuv::kCpuHasSSE41) {
        core->MNNFloat2Int8 = _SSE_MNNFloat2Int8;
        core->MNNInt8ScaleToFloat = _SSE_MNNInt8ScaleToFloat;
        // The SSE path has no separate fast gemm; both entries use the same kernel.
        core->Int8GemmKernel = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit;
        core->Int8GemmKernelFast = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit;
        core->ConvDepthwiseLineInt8 = _SSE_MNNLineDepthWiseInt8AddBiasScaleUnit;
#ifdef MNN_LOW_MEMORY
        core->Int8GemmKernel_W4 = _SSE_MNNGemmInt8AddBiasScale_16x4_w4;
#endif
    }
}
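// Image-process registration, called at the end of MNNFunctionInit(). The
// YUV/RGBA conversions only need the SSE baseline; the float conversions and
// samplers are gated on SSE4.1.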
void _SSE_ImageProcessInit(void* functions, int cpuFlags) {
    auto coreFunction = static_cast<MNN::CoreFunctions*>(functions);
    coreFunction->MNNRGBAToBGRA = _SSE_MNNRGBAToBGRA;
    coreFunction->MNNNV21ToRGBA = _SSE_MNNNV21ToRGBA;
    coreFunction->MNNNV21ToRGB = _SSE_MNNNV21ToRGB;
    coreFunction->MNNNV21ToBGRA = _SSE_MNNNV21ToBGRA;
    coreFunction->MNNNV21ToBGR = _SSE_MNNNV21ToBGR;
    // coreFunction->MNNsampleBilinearCommon = _SSE_sampleBilinearCommon;
    if (cpuFlags & libyuv::kCpuHasSSE41) {
        coreFunction->MNNC1ToFloatC1 = _SSE_MNNC1ToFloatC1;
        coreFunction->MNNC3ToFloatC3 = _SSE_MNNC3ToFloatC3;
        coreFunction->MNNC3ToFloatRGBA = _SSE_MNNC3ToFloatRGBA;
        coreFunction->MNNSamplerC4Nearest = _SSE_MNNSamplerC4Nearest;
        coreFunction->MNNSamplerC4Bilinear = _SSE_MNNSampleC4Bilinear;
    }
}
// ========= CommonOptFunction.cpp ===========
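// Exported symbols declared in CommonOptFunction.h. Fixed kernels forward
// straight to their SSE implementations; the rest dispatch through gFunc so
// they pick up whatever MNNFunctionInit() installed.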
void MNNCopyC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    _SSE_MNNCopyC4WithStride(source, dest, srcStride, dstStride, count);
}
void MNNAddC4WithStride(const float* source, float* dest, size_t srcStride, size_t dstStride, size_t count) {
    _SSE_MNNAddC4WithStride(source, dest, srcStride, dstStride, count);
}
void MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) {
    _SSE_MNNReluWithSlopeChannel(dst, src, slope, sizeQuad, depthQuad);
}
void MNNReluInt8(int8_t* dst, const int8_t* src, size_t size, ssize_t zeroPoint) {
    gFunc.MNNReluInt8(dst, src, size, zeroPoint);
}
void MNNHardSwish(float* dst, const float* src, size_t size) {
    gFunc.MNNHardSwish(dst, src, size);
}
void MNNGelu(float* dst, const float* src, size_t size, float* parameters) {
    gFunc.MNNGelu(dst, src, size, parameters);
}
void MNNExpC8(float* dest, const float* source, float* offset, const float* parameters, size_t countC8) {
    gFunc.MNNExpC8(dest, source, offset, parameters, countC8);
}
void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) {
    _SSE_MNNInt8ToInt16(dest, source, count);
}
void MNNSoftmax(float* dest, const float* source, size_t size) {
    gFunc.MNNSoftmax(dest, source, size);
}
void MNNNorm(float* dest, const float* source, const float* gamma, const float* beta, float epsilon, size_t size, bool RMSNorm) {
    gFunc.MNNNorm(dest, source, gamma, beta, epsilon, size, RMSNorm);
}