source/backend/cpu/BinaryUtils.hpp (448 lines of code) (raw):

#include <math.h> #include <algorithm> #include "compute/CommonOptFunction.h" #include "MNN_generated.h" template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryMax { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return std::max(x, y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryMin { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return std::min(x, y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryMul { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return x * y; } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryAdd { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return x + y; } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinarySub { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return x - y; } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryRealDiv { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return x / y; } }; /** Ref from onnxruntime/onnxruntime/core/providers/cpu/math/element_wise_ops.cc :: Modulus */ template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryModInt { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { auto res = x % y; if ((res < 0 && y > 0) || (res > 0 && y < 0)) { res += y; } return (_ErrorCode)res; } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryMod { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return fmodf(x, y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryGreater { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x > y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryLess { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x < y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryGreaterEqual { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x >= y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryLessEqual { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x <= y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryEqual { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x == y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryFloorDiv { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return floor(static_cast<double>(x) / y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryFloorMod { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return x - floor(x / y) * y; } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinarySquaredDifference { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (x - y) * (x - y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryPow { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return pow(x, y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryAtan2 { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return atan2(x, y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryLogicalOr { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x || y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryLogicalXor { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x ^ y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryNotEqual { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)((x != y) ? 1 : 0); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryLeftShift { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)(x << y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryBitwiseAnd { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)(x & y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryRightShift { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)(x >> y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryBitwiseOr { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)(x | y); } }; template <typename _Arg1, typename _Arg2, typename _ErrorCode> struct BinaryBitwiseXor { _ErrorCode operator()(const _Arg1& x, const _Arg2& y) const { return (_ErrorCode)(x ^ y); } }; template<typename Func, typename V, int pack, typename U, typename Tout> void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int needBroadcastIndex) { Func compute; const int sizeDivUnit = elementSize / pack; const int remainCount = elementSize - sizeDivUnit * pack; auto src0 = (const U*)(inputRaw0); auto src1 = (const U*)(inputRaw1); auto dst = (Tout*)outputRaw; if (-1 == needBroadcastIndex) { if (sizeDivUnit > 0) { int sizeDivC4 = sizeDivUnit / 4; int sizeDivUnitRemain = sizeDivUnit % 4; for (int i = 0; i < sizeDivC4; ++i) { V a0 = V::load(src0); V b0 = V::load(src1); V a1 = V::load(src0 + 1 * pack); V b1 = V::load(src1 + 1 * pack); V a2 = V::load(src0 + 2 * pack); V b2 = V::load(src1 + 2 * pack); V a3 = V::load(src0 + 3 * pack); V b3 = V::load(src1 + 3 * pack); V::save(dst, compute(a0, b0)); V::save(dst+1*pack, compute(a1, b1)); V::save(dst+2*pack, compute(a2, b2)); V::save(dst+3*pack, compute(a3, b3)); src0 += 4*pack; src1 += 4*pack; dst += 4*pack; } for (int i = 0; i < sizeDivUnitRemain; ++i) { V a = V::load(src0); V b = V::load(src1); V::save(dst, compute(a, b)); src0 += pack; src1 += pack; dst += pack; } } if (remainCount > 0) { U tempSrc0[pack]; U tempSrc1[pack]; Tout tempDst[pack]; ::memcpy(tempSrc0, src0, remainCount * sizeof(U)); ::memcpy(tempSrc1, src1, remainCount * sizeof(U)); V a = V::load(tempSrc0); V b = V::load(tempSrc1); V::save(tempDst, compute(a, b)); ::memcpy(dst, tempDst, remainCount * sizeof(U)); } } else if (0 == needBroadcastIndex) { const U srcValue0 = src0[0]; V a = V(srcValue0); if (sizeDivUnit > 0) { int sizeDivC4 = sizeDivUnit / 4; int sizeUnitRemain = sizeDivUnit % 4; for (int i = 0; i < sizeDivC4; ++i) { V b0 = V::load(src1); V b1 = V::load(src1 + 1*pack); V b2 = V::load(src1 + 2*pack); V b3 = V::load(src1 + 3*pack); V::save(dst, compute(a, b0)); V::save(dst+1*pack, compute(a, b1)); V::save(dst+2*pack, compute(a, b2)); V::save(dst+3*pack, compute(a, b3)); src1 += 4*pack; dst += 4*pack; } for (int i = 0; i < sizeUnitRemain; ++i) { V b = V::load(src1); V::save(dst, compute(a, b)); src1 += pack; dst += pack; } } if (remainCount > 0) { U tempSrc1[pack]; Tout tempDst[pack]; ::memcpy(tempSrc1, src1, remainCount * sizeof(U)); V b = V::load(tempSrc1); V::save(tempDst, compute(a, b)); ::memcpy(dst, tempDst, remainCount * sizeof(U)); } } else { const auto srcValue1 = static_cast<U>(src1[0]); V b = V(srcValue1); if (sizeDivUnit > 0) { int sizeDivC4 = sizeDivUnit / 4; int sizeUnitRemain = sizeDivUnit % 4; for (int i = 0; i < sizeDivC4; ++i) { const auto src0Ptr = src0; auto dstPtr = dst; V a0 = V::load(src0Ptr); V a1 = V::load(src0Ptr + 1*pack); V a2 = V::load(src0Ptr + 2*pack); V a3 = V::load(src0Ptr + 3*pack); V::save(dstPtr, compute(a0, b)); V::save(dstPtr+1*pack, compute(a1, b)); V::save(dstPtr+2*pack, compute(a2, b)); V::save(dstPtr+3*pack, compute(a3, b)); src0 += 4*pack; dst += 4*pack; } for (int i = 0; i < sizeUnitRemain; ++i) { const auto src0Ptr = src0; auto dstPtr = dst; V a = V::load(src0Ptr); V::save(dstPtr, compute(a, b)); src0 += pack; dst += pack; } } if (remainCount > 0) { U tempSrc0[pack]; Tout tempDst[pack]; ::memcpy(tempSrc0, src0, remainCount * sizeof(U)); V a = V::load(tempSrc0); V::save(tempDst, compute(a, b)); ::memcpy(dst, tempDst, remainCount * sizeof(U)); } } } template<typename Vec> struct VecBinaryAdd { Vec operator()(Vec& x, Vec& y) const { return x + y; } }; template<typename Vec> struct VecBinarySub { Vec operator()(Vec& x, Vec& y) const { return x - y; } }; template<typename Vec> struct VecBinaryMul { Vec operator()(Vec& x, Vec& y) const { return x * y; } }; template<typename Vec> struct VecBinaryMin { Vec operator()(Vec& x, Vec& y) const { return Vec::min(x, y); } }; template<typename Vec> struct VecBinaryMax { Vec operator()(Vec& x, Vec& y) const { return Vec::max(x, y); } }; template<typename Vec> struct VecBinarySqd { Vec operator()(Vec& x, Vec& y) const { return (x-y)*(x-y); } }; template<typename Vec> struct VecBinaryLess { Vec operator()(Vec& x, Vec& y) const { return x < y; } }; template<typename Vec> struct VecBinaryGreater { Vec operator()(Vec& x, Vec& y) const { return x > y; } }; template<typename Vec> struct VecBinaryLessEqual { Vec operator()(Vec& x, Vec& y) const { return x <= y; } }; template<typename Vec> struct VecBinaryGreaterEqual { Vec operator()(Vec& x, Vec& y) const { return x >= y; } }; template<typename Vec> struct VecBinaryEqual { Vec operator()(Vec& x, Vec& y) const { return x == y; } }; namespace MNN { template<typename Tin, typename Tout, typename Func> void execute(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int broadcastIndex) { Func f; const int input0DataCount = elementSize; const int input1DataCount = elementSize; const Tin* input0Data = (const Tin*)inputRaw0; const Tin* input1Data = (const Tin*)inputRaw1; Tout* outputData = (Tout*)outputRaw; if (broadcastIndex == 0) { // data count == 1, not only mean scalar input, maybe of shape (1, 1, 1, ...,1) for (int i = 0; i < input1DataCount; i++) { outputData[i] = (Tout)(f(input0Data[0], input1Data[i])); } } else if (broadcastIndex == 1) { for (int i = 0; i < input0DataCount; i++) { outputData[i] = (Tout)(f(input0Data[i], input1Data[0])); } } else { // both input contains more than one element,which means no scalar input for (int i = 0; i < input0DataCount; i++) { outputData[i] = (Tout)(f(input0Data[i], input1Data[i])); } } } template<typename Tin, typename Tout, typename Func> void executeInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const QuanPrePostParameters* params, size_t elementSize, size_t needBroadcast) { Func f; int size = static_cast<int>(elementSize); #ifdef MNN_USE_NEON size *= 4; #endif float inp0 = 0, inp1 = 0, output = 0; #ifdef MNN_USE_SSE const int offset = 128; const uint8_t* inputData0 = (uint8_t*)inputRaw0; const uint8_t* inputData1 = (uint8_t*)inputRaw1; uint8_t* outputData = (uint8_t*)outputRaw; #else const int offset = 0; const int8_t* inputData0 = (int8_t*)inputRaw0; const int8_t* inputData1 = (int8_t*)inputRaw1; int8_t* outputData = (int8_t*)outputRaw; #endif const int maxValue = static_cast<int32_t>(params->maxValue) + offset; const int minValue = static_cast<int32_t>(params->minValue) + offset; for (int i = 0; i < size; ++i) { if (needBroadcast == 0) { inp0 = (inputData0[0]- offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; inp1 = (inputData1[i]- offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; output = f(inp0, inp1); } else if (needBroadcast == 1) { inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; inp1 = (inputData1[0] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; output = f(inp0, inp1); } else { inp0 = (inputData0[i] - offset - params->inputZeroPoint[0]) * inputScalesFp32[0]; inp1 = (inputData1[i] - offset - params->inputZeroPoint[1]) * inputScalesFp32[1]; output = f(inp0, inp1); } int value = (int)roundf(output * inputScalesFp32[2]) + offset + static_cast<int32_t>(params->outputZeroPoint[0]); if (value > maxValue) { value = maxValue; } if (value < minValue) { value = minValue; } outputData[i] = value; } } template<typename V, int pack, typename U> MNNBinaryExecute selectVector(int type) { switch (type) { case BinaryOpOperation_ADD: return executeVec<VecBinaryAdd<V>, V, pack, U, U>; case BinaryOpOperation_SUB: return executeVec<VecBinarySub<V>, V, pack, U, U>; case BinaryOpOperation_MUL: return executeVec<VecBinaryMul<V>, V, pack, U, U>; case BinaryOpOperation_MINIMUM: return executeVec<VecBinaryMin<V>, V, pack, U, U>; case BinaryOpOperation_MAXIMUM: return executeVec<VecBinaryMax<V>, V, pack, U, U>; case BinaryOpOperation_SquaredDifference: return executeVec<VecBinarySqd<V>, V, pack, U, U>; case BinaryOpOperation_LESS: return executeVec<VecBinaryLess<V>, V, pack, U, int32_t>; case BinaryOpOperation_LESS_EQUAL: return executeVec<VecBinaryLessEqual<V>, V, pack, U, int32_t>; case BinaryOpOperation_GREATER: return executeVec<VecBinaryGreater<V>, V, pack, U, int32_t>; case BinaryOpOperation_GREATER_EQUAL: return executeVec<VecBinaryGreaterEqual<V>, V, pack, U, int32_t>; case BinaryOpOperation_EQUAL: return executeVec<VecBinaryEqual<V>, V, pack, U, int32_t>; } return nullptr; } };