source/backend/cpu/compute/StrassenMatmulComputor.hpp (36 lines of code) (raw):

//
//  StrassenMatmulComputor.hpp
//  MNN
//
//  Created by MNN on 2019/02/11.
//  Copyright © 2018, Alibaba Group Holding Limited
//

#ifndef StrassenMatmulComputor_hpp
#define StrassenMatmulComputor_hpp

#include <functional>
#include "core/BufferAllocator.hpp"
#include "core/Backend.hpp"
namespace MNN {
/**
 Strassen-Winograd matrix-multiplication encoder for the CPU backend.

 Based on:
 Boyer, B., Dumas, J.-G., Pernet, C., & Zhou, W. (2007). Memory efficient
 scheduling of Strassen-Winograd's matrix multiplication algorithm.
 Proceedings of the 2009 International Symposium on Symbolic and Algebraic
 Computation (ISSAC '09), 55. ACM Press.
 Retrieved from http://arxiv.org/abs/0707.2347

 Uses the schedule of Table 2.
 */
class StrassenMatrixComputor {
public:
    // bn: backend used for buffer allocation and kernel dispatch (non-owning).
    // multithread: whether encoded tasks may be split across threads.
    // maxDepth: maximum recursion depth of the Strassen decomposition.
    StrassenMatrixComputor(Backend* bn, bool multithread, int maxDepth);
    virtual ~StrassenMatrixComputor();

    /*
     It is assumed that:
     P = core->pack
     A is a matrix where each element is a (P,1) vector  : [l/P], e, P
     B is a matrix where each element is a (hP,1) vector : h, l, hP

     inputs[0] is the transpose of A: AT
     inputs[1] is the transpose of B: BT
     outputs[0] is the transpose of C: CT

     C is a matrix where each element is a (P,1) vector, laid out the same
     as A: [h/P], e, P

     if (inputs.size() > 2) {
        inputs[2] is the origin CO: CT
        CO may be the same as C, or broadcast in length(1):
            hC4, e, P  or  hC4, 1, P
     }
     Computes: C = alpha * AB + beta * CO, where alpha must be 1.0f.

     postParameters:
     0: alpha
     1: beta
     2: min
     3: max

     if (postParameters.empty()) {
        alpha = 1.0f
        beta  = 0.0f
        min   = -FLT_MAX
        max   = FLT_MAX
     }
     */
    ErrorCode onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const std::vector<float>& postParameters = {}, int l = 0, int h = 0);

    // Raw-chunk variant: encode C[e,h] = A[e,l] * B[l,h] (+ optional Bias),
    // with as/bs/cs the per-unit strides of A/B/C respectively.
    // NOTE(review): stride units (elements vs. bytes) are not visible from
    // this header — confirm against the implementation before relying on it.
    ErrorCode onEncode(int e, int l, int h, int as, int bs, int cs, const MemChunk AT, const MemChunk BT, MemChunk CT, bool useBias, const MemChunk Bias = MemChunk(), const std::vector<float>& postParameters = {});

    // Legacy raw-pointer overload, superseded by the MemChunk version above.
    // ErrorCode onEncode(int e, int l, int h, int as, int bs, int cs, const uint8_t* AT, const uint8_t* BT, uint8_t* CT, bool useBias, const uint8_t* Bias = nullptr, const std::vector<float>& postParameters = {});

    // Run the previously encoded function list; the optional pointers allow
    // rebinding the A/B/CO/C base addresses at execution time.
    void onExecute(const uint8_t* AT = nullptr, const uint8_t* BT = nullptr, const uint8_t* COT = nullptr, uint8_t* CT = nullptr);

    // Drop all encoded state so the computor can be reused for a new encode.
    void onReset();

protected:
    // Non-owning accessor for the backend passed at construction.
    Backend* backend() const { return mBackend; }

private:
    // Describes one matrix operand as a slice of an entry in mStack.
    struct MatrixInfo {
        int stackIndex;      // index into mStack identifying the buffer
        int offsetBytes;     // byte offset of this matrix within the buffer
        int lineStrideBytes; // byte stride between consecutive lines
    };
    // Recursive Strassen encoder; falls back to a direct multiply once
    // currentDepth reaches mMaxDepth or the problem is too small to split.
    ErrorCode _generateMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, int currentDepth, const std::vector<float>& postParameters);
    // Non-recursive (trivial) matrix multiply used at the recursion leaves.
    ErrorCode _generateTrivalMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& postParameters);
    // Basic matrix multiply building block shared by the encoders.
    ErrorCode _generateBasicMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& postParameters);

    // Encoded task list: each entry pairs a per-thread function with an int
    // (presumably the number of threads to run it on — confirm in the .cpp).
    std::vector<std::pair<std::function<void(int tId)>, int>> mFunctions;

    int mMaxDepth;            // recursion-depth limit from the constructor
    bool mSupportMultiThread; // multithread flag from the constructor
    Backend* mBackend;        // non-owning backend pointer

    // Buffers referenced by MatrixInfo::stackIndex (inputs, outputs and
    // Strassen temporaries).
    std::vector<MemChunk> mStack;
    // Bytes per weight element; defaults to 4 (fp32), presumably overridden
    // for lower-precision backends.
    int mWeightBytes = 4;
};
} // namespace MNN

#endif /* StrassenMatmulComputor_hpp */