source/backend/cpu/compute/StrassenMatmulComputor.cpp (508 lines of code) (raw):
//
// StrassenMatmulComputor.cpp
// MNN
//
// Created by MNN on 2019/02/11.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include "StrassenMatmulComputor.hpp"
#include "DenseConvolutionTiledExecutor.hpp"
#include "CommonOptFunction.h"
#include "backend/cpu/CPUBackend.hpp"
#include <string.h>
#include <limits.h>
#include "core/AutoStorage.h"
#include "core/Macro.h"
#include "core/Concurrency.h"
#include "core/TensorUtils.hpp"
//#define MNN_OPEN_TIME_TRACE
#include <MNN/AutoTime.hpp>
#include "math/Vec.hpp"
#include "math/Matrix.hpp"
#include "core/BufferAllocator.hpp"
namespace MNN {
// RAII wrapper around a MemChunk obtained from a BufferAllocator: frees the
// chunk in the destructor so early-return (e.g. OUT_OF_MEMORY) paths cannot
// leak it.
class AutoMemory {
public:
    AutoMemory(int size, BufferAllocator* allocator) {
        mContent = allocator->alloc(size);
        mAllocator = allocator;
    }
    // Sole owner of the allocation: copying or moving would release the same
    // chunk twice, so both are forbidden (Rule of Five).
    AutoMemory(const AutoMemory&) = delete;
    AutoMemory& operator=(const AutoMemory&) = delete;
    ~ AutoMemory() {
        // alloc() can fail; only hand a valid chunk back to the allocator.
        if (!mContent.invalid()) {
            mAllocator->free(mContent);
        }
    }
    // Non-owning view of the chunk; invalid() when the allocation failed.
    const MemChunk& get() const {
        return mContent;
    }
private:
    MemChunk mContent;           // allocated chunk (invalid on OOM)
    BufferAllocator* mAllocator; // allocator that owns the backing pool
};
// bn: backend supplying the CPU function table (CoreFunctions).
// multithread: when true, encoded tasks may be split across the backend's
// thread pool; otherwise everything runs on one thread.
// maxDepth: maximum Strassen recursion depth before falling back to the
// plain tiled matmul.
StrassenMatrixComputor::StrassenMatrixComputor(Backend* bn, bool multithread, int maxDepth) : mBackend(bn) {
    mMaxDepth = maxDepth;
    mSupportMultiThread = multithread;
    auto core = static_cast<CPUBackend*>(backend())->functions();
    // Default the weight element size to the backend's float byte width;
    // onEncode refreshes it before encoding.
    mWeightBytes = core->bytes;
} // (removed stray ';' after the function body)
// Trivial destructor: all members clean up after themselves.
StrassenMatrixComputor::~StrassenMatrixComputor() = default;
// Encode a straightforward tiled matmul C(e,h) = A(e,l) * B(l,h).
// A is stored packed (core->pack channels per line), B is a pre-packed weight
// (lP x hP tiles), C is written packed. COT (stackIndex == -1 when absent)
// supplies a bias row; 'active' carries fused post parameters forwarded to the
// packed-matmul kernels. Nothing runs here: a lambda is appended to
// mFunctions and executed later by onExecute.
ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& active) {
// Generate Trival Matrix Multiply
MNN_ASSERT(e > 0);
auto core = static_cast<CPUBackend*>(backend())->functions();
int bytes = core->bytes;
auto aStride = AT.lineStrideBytes;
auto bStride = BT.lineStrideBytes;
auto cStride = CT.lineStrideBytes;
int eP, lP, hP;
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
// Bytes B must skip per line beyond the packed l-payload (sub-matrix views
// of a wider B have a positive extra stride).
auto bExtraStride = bStride - UP_DIV(l, lP)*lP*hP * mWeightBytes;
MNN_ASSERT(bExtraStride >= 0);
// Per-thread scratch for repacking an eP-wide slice of A into kernel layout.
auto tileBufferBasic = static_cast<CPUBackend*>(backend())->getBufferAllocator()->alloc(numberThread * UP_DIV(l, lP) * eP * lP * bytes);
if (tileBufferBasic.invalid()) {
return OUT_OF_MEMORY;
}
// Full eP-wide tiles, plus one remainder of xCount columns.
int unitNumber = e / eP;
int xCount = e - unitNumber * eP;
auto eReal = aStride / core->bytes / core->pack;
auto matmulUnit = core->MNNPackedMatMul;
auto matmulRemain = core->MNNPackedMatMulRemain;
mFunctions.emplace_back(
std::make_pair([cStride, l, h, xCount, AT, BT, CT, COT, tileBufferBasic, unitNumber, bExtraStride, numberThread, eReal, eP, active, matmulUnit, matmulRemain, this](int tId) {
auto core = static_cast<CPUBackend*>(backend())->functions();
// Kernel parameter block; layout is defined by MNNPackedMatMul.
// Slots set from named values: [1]=l, [2]=h, [3]=C line stride (bytes),
// [5]=B extra stride (bytes); [0] is the remain-block byte width.
size_t parameters[7];
parameters[0] = xCount * core->bytes;
parameters[1] = l;
parameters[2] = h;
parameters[3] = cStride;
parameters[4] = 0;
parameters[5] = bExtraStride;
parameters[6] = 0;
// Each thread gets its own slice of the repack scratch buffer.
auto tileHost = tileBufferBasic.ptr() + eP * parameters[1] * tId * core->bytes;
const float* postParametersPtr = nullptr;
if (!active.empty()) {
postParametersPtr = active.data();
}
// Resolve stack indices to live pointers at execution time, so onExecute
// can swap buffers without re-encoding.
auto aHost = mStack[AT.stackIndex].ptr() + AT.offsetBytes;
auto bHost = mStack[BT.stackIndex].ptr() + BT.offsetBytes;
auto cHost = mStack[CT.stackIndex].ptr() + CT.offsetBytes;
const uint8_t* biasPtr = nullptr;
if (-1 != COT.stackIndex) {
biasPtr = mStack[COT.stackIndex].ptr() + COT.offsetBytes;
}
auto packUnit = core->bytes * core->pack;
// info/stride describe the A-repack job for MNNPackC4ForMatMul_A.
int32_t info[4];
int32_t stride[4];
stride[0] = eP;
stride[1] = (int32_t)parameters[1];
stride[2] = 0;
stride[3] = 0;
info[0] = 1;
info[1] = eReal;
info[2] = eP;
info[3] = 1;
// Full eP-wide tiles, strided across threads.
for (int i = tId; i < unitNumber; i+=numberThread) {
int xStart = i * eP;
auto aStart = aHost + xStart * packUnit;
core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride);
matmulUnit((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, parameters, postParametersPtr, (const float*)biasPtr, nullptr, nullptr);
}
// Remainder (xCount < eP columns) handled by the last thread only.
if (tId != numberThread -1) {
return;
}
if (xCount > 0) {
stride[0] = xCount;
stride[1] = (int32_t)parameters[1];
info[2] = xCount;
int xStart = unitNumber * eP;
auto aStart = aHost + xStart * packUnit;
// Copy
core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride);
matmulRemain((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, xCount, parameters, postParametersPtr, (const float*)biasPtr, nullptr, nullptr);
}
}, numberThread));
// Freed during encoding, before the lambda ever runs: presumably MNN's
// buffer allocator plans memory statically so the chunk stays addressable
// at execute time — TODO confirm allocator lifetime semantics.
static_cast<CPUBackend*>(backend())->getBufferAllocator()->free(tileBufferBasic);
return NO_ERROR;
}
// Row-parallel matrix subtract: c = a - b over lSub packed rows, rows strided
// across workers. Expects `tId` and `numberThread` to be in the enclosing
// scope; strides are in bytes.
#define MNNMATRIX_SUB_MULTITHREAD(c_, a_, b_, widthC4, cStride, aStride, bStride, lSub, core) \
{\
auto c = c_;\
auto b = b_;\
auto a = a_;\
for (int y = tId; y < lSub; y+=numberThread) {\
core->MNNMatrixSub((float*)(c + y * cStride), (float*)(a + y * aStride), (float*)(b + y * bStride), widthC4, 0, 0, 0, 1);\
}\
}
// Row-parallel matrix add: c = a + b, same conventions as the subtract macro.
#define MNNMATRIX_ADD_MULTITHREAD(c_, a_, b_, widthC4, cStride, aStride, bStride, lSub, core) \
{\
auto c = c_;\
auto b = b_;\
auto a = a_;\
for (int y = tId; y < lSub; y+=numberThread) {\
core->MNNMatrixAdd((float*)(c + y * cStride), (float*)(a + y * aStride), (float*)(b + y * bStride), widthC4, 0, 0, 0, 1);\
}\
}
// Encode C = A * B (+ optional bias/activation), splitting the reduction
// dimension l into chunks of at most lLimit when l is large (the 32768
// heuristic presumably bounds the working set — TODO confirm). The first
// chunk writes C directly; later chunks multiply into a temp buffer and are
// accumulated with a matrix add. Bias/post parameters are applied once at
// the end.
ErrorCode StrassenMatrixComputor::_generateBasicMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, const std::vector<float>& postParameters) {
auto core = static_cast<CPUBackend*>(backend())->functions();
int eP, lP, hP;
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
int lLimit = 32768 / (std::min(eP, e) + hP);
if (l <= lLimit) {
// Small reduction: single trivial matmul handles everything.
return _generateTrivalMatMul(e, l, h, AT, BT, CT, COT, postParameters);
}
{
// Round lLimit down to a multiple of both lP and core->pack so chunk
// boundaries stay aligned for both A's and B's packing.
auto lUnit = std::max(lP, core->pack);
lLimit = lLimit / lUnit * lUnit;
}
int unit = UP_DIV(l, lLimit);
auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
// Temp C for partial products of chunks i >= 1.
AutoMemory CAddr(e * UP_DIV(h, core->pack) * core->pack * core->bytes, allocator);
MatrixInfo CTemp;
CTemp.stackIndex = (int)mStack.size();
CTemp.offsetBytes = 0;
CTemp.lineStrideBytes = e * core->bytes * core->pack;
mStack.emplace_back(CAddr.get());
MatrixInfo Empty;
Empty.stackIndex = -1;
auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
auto cHeight = UP_DIV(h, core->pack);
for (int i=0; i<unit; ++i) {
int lS = i * lLimit;
int lE = lS + lLimit;
if (lE > l) {
lE = l;
}
if (0 == i) {
// First write to output
auto code = _generateTrivalMatMul(e, lE-lS, h, AT, BT, CT, Empty, {});
if (NO_ERROR != code) {
return code;
}
continue;
}
// Offset A by whole packed lines, B by packed lS*hP elements.
MatrixInfo tempA = AT;
MatrixInfo tempB = BT;
tempA.offsetBytes = AT.offsetBytes + lS / core->pack * AT.lineStrideBytes;
// tempB.offsetBytes = BT.offsetBytes + lS * hP * core->bytes;
tempB.offsetBytes = BT.offsetBytes + lS * hP * mWeightBytes;
auto code = _generateTrivalMatMul(e, lE-lS, h, tempA, tempB, CTemp, Empty, {});
if (NO_ERROR != code) {
return code;
}
// Add CTemp to C
auto f1 = [CT, CTemp, e, cHeight, numberThread, core, this](int tId) {
auto c11Ptr = mStack[CT.stackIndex].ptr() + CT.offsetBytes;
auto xAddr = mStack[CTemp.stackIndex].ptr() + CTemp.offsetBytes;
MNNMATRIX_ADD_MULTITHREAD(c11Ptr, c11Ptr, xAddr, e, CT.lineStrideBytes, CT.lineStrideBytes, CTemp.lineStrideBytes, cHeight, core);
};
mFunctions.emplace_back(std::make_pair(f1, numberThread));
}
// Apply bias + activation after all partial sums are accumulated.
if (!postParameters.empty() && COT.stackIndex >= 0) {
if (1 == numberThread) {
// Single thread: one broadcast call over the whole output.
auto postFunction = [CT, COT, e, cHeight, numberThread, postParameters, core, this](int tId) {
auto biasPtr = (const float*)(mStack[COT.stackIndex].ptr() + COT.offsetBytes);
auto width = e;
auto height = cHeight;
auto c11Ptr = mStack[CT.stackIndex].ptr() + CT.offsetBytes;
core->MNNAxByClampBroadcastUnit((float*)c11Ptr, (float*)c11Ptr, biasPtr, width, CT.lineStrideBytes / core->bytes, CT.lineStrideBytes / core->bytes, height, postParameters.data());
};
mFunctions.emplace_back(std::make_pair(postFunction, 1));
} else {
// Multi-thread: each thread handles a strided subset of rows.
auto postFunction = [CT, COT, e, cHeight, numberThread, postParameters, core, this](int tId) {
auto width = e;
auto height = cHeight;
auto c11Ptr = mStack[CT.stackIndex].ptr() + CT.offsetBytes;
auto biasPtr = mStack[COT.stackIndex].ptr() + COT.offsetBytes;
for (int y = tId; y < height; y+=numberThread) {
core->MNNAxByClampBroadcastUnit((float*)(c11Ptr + y * CT.lineStrideBytes), (float*)(c11Ptr + y * CT.lineStrideBytes), (const float*)(biasPtr + y * core->bytes * core->pack), width, 0, 0, 1, postParameters.data());
}
};
mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
}
}
return NO_ERROR;
}
// Recursive Strassen encoding of C = A * B (+ optional bias). Each level
// splits e, l, h in half (aligned to the packing units), computes the seven
// Strassen sub-products with two scratch matrices X and Y, then handles the
// leftover rows (remainE) and columns (remainH) with the basic matmul.
// Falls back to _generateBasicMatMul when the recursion depth is exhausted,
// the dimensions cannot be split on pack boundaries, or the cost model says
// the expand overhead outweighs the saved multiply work.
ErrorCode StrassenMatrixComputor::_generateMatMul(int e, int l, int h, const MatrixInfo& AT, const MatrixInfo& BT, const MatrixInfo& CT, const MatrixInfo& COT, int currentDepth, const std::vector<float>& postParameters) {
auto core = static_cast<CPUBackend*>(backend())->functions();
auto aUnit = core->pack;
auto numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1;
int eP, lP, hP;
core->MNNGetMatMulPackMode(&eP, &lP, &hP);
MNN_ASSERT(hP % core->pack == 0 || core->pack % hP == 0);
// Half-sizes rounded down to the packing units; remainders handled at the end.
auto eSub = (e / eP) / 2 * eP;
auto lMinDiv = std::max(core->pack * 2, 2 * lP);
auto hSub = (h / std::max(hP, core->pack)) / 2 * std::max(hP, core->pack);
auto remainH = h - hSub * 2;
auto remainE = e - eSub * 2;
int packHUnit = 1;
if (core->pack > hP) {
packHUnit = core->pack / hP;
}
// Fall back when splitting is impossible or not allowed at this depth.
if (currentDepth >= mMaxDepth || eSub == 0 || hSub == 0 || l % (2 * core->pack) != 0 || l % (2 * lP) || l % (2 * packHUnit) != 0) {
return _generateBasicMatMul(e, l, h, AT, BT, CT, COT, postParameters);
}
auto lSub = l / 2;
auto lSubUnit = lSub / core->pack;
auto bWidth = lSub * hP / core->pack;
auto aHeight = lSub / core->pack;
auto cHeight = hSub / core->pack;
auto bHeight = hSub / hP;
/*
Compute the memory read / write cost for expand
*/
auto bHSub = bHeight;
float AComputeCost = 4 * ((float)eSub * lSub);
float BComputeCost = 4 * (float)lSub * bHSub * hP;
float CComputeCost = 7 * (float)eSub * hSub;
// Saved multiply cost of doing 7 instead of 8 sub-matmuls at this size.
float saveMatMulCost = (e / eP) * (aUnit * eP * hSub / core->pack + lSubUnit * eP * aUnit + lSub * bHSub * hP);
const float penalty = core->penalty;//FIXME: Find beter way to set it
float saveCost = saveMatMulCost - (AComputeCost + BComputeCost + CComputeCost) * penalty;
if (saveCost <= 0.0f) {
return _generateBasicMatMul(e, l, h, AT, BT, CT, COT, postParameters);
}
// Strassen Construct
auto bn = backend();
auto allocator = static_cast<CPUBackend*>(backend())->getBufferAllocator();
currentDepth += 1;
auto maxlH = std::max(lSub, hSub);
// Y: weight-layout scratch; X: activation/output-layout scratch (also CX).
AutoMemory YAddr(hSub * lSub * mWeightBytes, allocator);
AutoMemory XAddr(maxlH * eSub * core->bytes, allocator);
if (XAddr.get().invalid() || YAddr.get().invalid()) {
return OUT_OF_MEMORY;
}
MatrixInfo Y;
Y.stackIndex = (int)mStack.size();
mStack.emplace_back(YAddr.get());
Y.offsetBytes = 0;
Y.lineStrideBytes = lSub * mWeightBytes * hP;
MatrixInfo X;
X.stackIndex = (int)mStack.size();
X.offsetBytes = 0;
X.lineStrideBytes = eSub * core->bytes * core->pack;
mStack.emplace_back(XAddr.get());
// CX aliases X's storage, used as a C-shaped destination for P1.
MatrixInfo CX;
CX.stackIndex = X.stackIndex;
CX.offsetBytes = 0;
CX.lineStrideBytes = eSub * core->bytes * core->pack;
// Quadrant views: a12/a22 offset by lSubUnit lines (l split), a21/a22 by
// eSub packed columns (e split).
MatrixInfo a11 = AT;
MatrixInfo a12 = AT;
a12.offsetBytes = AT.offsetBytes + AT.lineStrideBytes * lSubUnit;
MatrixInfo a21 = AT;
a21.offsetBytes = AT.offsetBytes + eSub * core->pack * core->bytes;
MatrixInfo a22 = AT;
a22.offsetBytes = AT.offsetBytes + eSub * core->pack * core->bytes + AT.lineStrideBytes * lSubUnit;
// B quadrants: h split by whole hP lines, l split by lSub*hP packed elements.
MatrixInfo b11 = BT;
MatrixInfo b12 = BT;
b12.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (hSub / hP);
MatrixInfo b21 = BT;
b21.offsetBytes = BT.offsetBytes + lSub * hP * mWeightBytes;
MatrixInfo b22 = BT;
b22.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (hSub / hP) + lSub * hP * mWeightBytes;
// C quadrants: h split by packed lines, e split by packed columns.
MatrixInfo c11 = CT;
MatrixInfo c12 = CT;
c12.offsetBytes = CT.offsetBytes + CT.lineStrideBytes * (hSub / core->pack);
MatrixInfo c21 = CT;
c21.offsetBytes = CT.offsetBytes + eSub * core->pack * core->bytes;
MatrixInfo c22 = CT;
c22.offsetBytes = CT.offsetBytes + eSub * core->pack * core->bytes + CT.lineStrideBytes * (hSub / core->pack);
MatrixInfo Empty;
Empty.stackIndex = -1;
{
// S3=A11-A21, T3=B22-B12, P7=S3*T3
auto f = [a11, a21, b22, b12, X, Y, eSub, lSub, hSub, numberThread, core, hP, this, bWidth, aHeight, bHeight](int tId) {
auto xAddr = mStack[X.stackIndex].ptr() + X.offsetBytes;
auto yAddr = mStack[Y.stackIndex].ptr() + Y.offsetBytes;
auto a11Ptr = mStack[a11.stackIndex].ptr() + a11.offsetBytes;
auto a21Ptr = mStack[a21.stackIndex].ptr() + a21.offsetBytes;
MNNMATRIX_SUB_MULTITHREAD(xAddr, a11Ptr, a21Ptr, eSub, X.lineStrideBytes, a11.lineStrideBytes, a21.lineStrideBytes, aHeight, core);
MNNMATRIX_SUB_MULTITHREAD(yAddr, mStack[b22.stackIndex].ptr() + b22.offsetBytes, mStack[b12.stackIndex].ptr() + b12.offsetBytes, bWidth, Y.lineStrideBytes, b22.lineStrideBytes, b12.lineStrideBytes, bHeight, core);
};
mFunctions.emplace_back(std::make_pair(f, numberThread));
// P7 lands directly in c21.
auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c21, Empty, currentDepth, {});
if (code != NO_ERROR) {
return code;
}
}
{
// S1=A21+A22, T1=B12-B11, P5=S1T1
auto f = [a22, a21, b11, b12, X, Y, eSub, lSub, hSub, numberThread, hP, core, this, bWidth, aHeight, bHeight](int tId) {
MNNMATRIX_ADD_MULTITHREAD(mStack[X.stackIndex].ptr() + X.offsetBytes, mStack[a21.stackIndex].ptr() + a21.offsetBytes, mStack[a22.stackIndex].ptr() + a22.offsetBytes , eSub, X.lineStrideBytes, a21.lineStrideBytes, a22.lineStrideBytes, aHeight, core);
MNNMATRIX_SUB_MULTITHREAD(mStack[Y.stackIndex].ptr() + Y.offsetBytes, mStack[b12.stackIndex].ptr() + b12.offsetBytes, mStack[b11.stackIndex].ptr() + b11.offsetBytes, bWidth, Y.lineStrideBytes, b12.lineStrideBytes, b11.lineStrideBytes, bHeight, core);
};
mFunctions.emplace_back(std::make_pair(f, numberThread));
// P5 lands in c22.
auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c22, Empty, currentDepth, {});
if (code != NO_ERROR) {
return code;
}
}
{
// S2=S1-A11, T2=B22-T1, P6=S2T2
// Reuses X (currently S1) and Y (currently T1) in place.
auto f = [a11, b22, X, Y, eSub, lSub, hSub, numberThread, hP, core, this, bWidth, aHeight, bHeight](int tId) {
auto xAddr = mStack[X.stackIndex].ptr() + X.offsetBytes;
auto yAddr = mStack[Y.stackIndex].ptr() + Y.offsetBytes;
MNNMATRIX_SUB_MULTITHREAD(xAddr, xAddr, mStack[a11.stackIndex].ptr() + a11.offsetBytes, eSub, X.lineStrideBytes, X.lineStrideBytes, a11.lineStrideBytes, aHeight, core);
MNNMATRIX_SUB_MULTITHREAD(yAddr, mStack[b22.stackIndex].ptr() + b22.offsetBytes, yAddr, bWidth, Y.lineStrideBytes, b22.lineStrideBytes, Y.lineStrideBytes, bHeight, core);
};
mFunctions.emplace_back(std::make_pair(f, numberThread));
// P6 lands in c12.
auto code = _generateMatMul(eSub, lSub, hSub, X, Y, c12, Empty, currentDepth, {});
if (code != NO_ERROR) {
return code;
}
}
{
// S4=A12-S2, P3=S4*B22, P1=A11*B11
auto f = [a12, X, eSub, aHeight, numberThread, core, this](int tId) {
auto xAddr = mStack[X.stackIndex].ptr() + X.offsetBytes;
MNNMATRIX_SUB_MULTITHREAD(xAddr, mStack[a12.stackIndex].ptr() + a12.offsetBytes, xAddr, eSub, X.lineStrideBytes, a12.lineStrideBytes, X.lineStrideBytes, aHeight, core);
};
mFunctions.emplace_back(std::make_pair(f, numberThread));
// P3 lands in c11 temporarily; P1 goes to CX (X's storage).
auto code = _generateMatMul(eSub, lSub, hSub, X, b22, c11, Empty, currentDepth, {});
if (code != NO_ERROR) {
return code;
}
code = _generateMatMul(eSub, lSub, hSub, a11, b11, CX, Empty, currentDepth, {});
if (code != NO_ERROR) {
return code;
}
}
{
// U2=P1+P6, U3=U2+P7, U4=U2+P5, U7=U3+P5
// U5=U4+P3, T4=T2-B21, P4=A22*T4
auto f = [c11, c12, c21, c22, b21, X, Y, eSub, bWidth, cHeight, bHeight, numberThread, core, this](int tId) {
for (int y = tId; y < cHeight; y+=numberThread) {
core->MNNStrassenMergeCFunction((float*)(mStack[c11.stackIndex].ptr() + c11.offsetBytes + y * c11.lineStrideBytes), (float*)(mStack[c12.stackIndex].ptr() + c12.offsetBytes + y * c12.lineStrideBytes), (float*)(mStack[c21.stackIndex].ptr() + c21.offsetBytes + y * c21.lineStrideBytes), (float*)(mStack[c22.stackIndex].ptr() + c22.offsetBytes + y * c22.lineStrideBytes), (float*)(mStack[X.stackIndex].ptr() + X.offsetBytes + y * X.lineStrideBytes), 0, eSub, 1);
}
auto yAddr = mStack[Y.stackIndex].ptr() + Y.offsetBytes;
MNNMATRIX_SUB_MULTITHREAD(yAddr, yAddr, mStack[b21.stackIndex].ptr() + b21.offsetBytes, bWidth, Y.lineStrideBytes, Y.lineStrideBytes, b21.lineStrideBytes, bHeight, core);
};
mFunctions.emplace_back(std::make_pair(f, numberThread));
// P4 lands in c11 (subtracted from c21 in the next step).
auto code = _generateMatMul(eSub, lSub, hSub, a22, Y, c11, Empty, currentDepth, {});
if (code != NO_ERROR) {
return code;
}
}
{
// U6=U3-P4, P2=A12*B21, U1=P1+P2
auto f0 = [c11, c21, eSub, cHeight, numberThread, core, this](int tId) {
auto cw = eSub;
auto c21Addr = mStack[c21.stackIndex].ptr() + c21.offsetBytes;
MNNMATRIX_SUB_MULTITHREAD(c21Addr, c21Addr, mStack[c11.stackIndex].ptr() + c11.offsetBytes, cw, c21.lineStrideBytes, c21.lineStrideBytes, c11.lineStrideBytes, cHeight, core);
};
mFunctions.emplace_back(std::make_pair(f0, numberThread));
auto code = _generateMatMul(eSub, lSub, hSub, a12, b21, c11, Empty, currentDepth, {});
if (code != NO_ERROR) {
return code;
}
// U1 = P2 (now in c11) + P1 (still cached in X's storage).
auto f1 = [c11, X, eSub, cHeight, numberThread, core, this](int tId) {
auto cw = eSub;
auto c11Ptr = mStack[c11.stackIndex].ptr() + c11.offsetBytes;
auto xAddr = mStack[X.stackIndex].ptr() + X.offsetBytes;
MNNMATRIX_ADD_MULTITHREAD(c11Ptr, c11Ptr, xAddr, cw, c11.lineStrideBytes, c11.lineStrideBytes, X.lineStrideBytes, cHeight, core);
};
mFunctions.emplace_back(std::make_pair(f1, numberThread));
// Fused bias + activation over the full 2x2 quadrant area.
if (!postParameters.empty() && COT.stackIndex >= 0) {
if (1 == numberThread) {
auto postFunction = [c11, COT, eSub, cHeight, numberThread, postParameters, core, this](int tId) {
auto biasPtr = (const float*)(mStack[COT.stackIndex].ptr() + COT.offsetBytes);
auto width = eSub * 2;
auto height = cHeight * 2;
auto c11Ptr = mStack[c11.stackIndex].ptr() + c11.offsetBytes;
core->MNNAxByClampBroadcastUnit((float*)c11Ptr, (float*)c11Ptr, biasPtr, width, c11.lineStrideBytes / core->bytes, c11.lineStrideBytes / core->bytes, height, postParameters.data());
};
mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
} else {
auto postFunction = [c11, COT, eSub, cHeight, numberThread, postParameters, core, this](int tId) {
auto width = eSub * 2;
auto height = cHeight * 2;
auto c11Ptr = mStack[c11.stackIndex].ptr() + c11.offsetBytes;
auto biasPtr = mStack[COT.stackIndex].ptr() + COT.offsetBytes;
for (int y = tId; y < height; y+=numberThread) {
core->MNNAxByClampBroadcastUnit((float*)(c11Ptr + y * c11.lineStrideBytes), (float*)(c11Ptr + y * c11.lineStrideBytes), (const float*)(biasPtr + y * core->bytes * core->pack), width, 0, 0, 1, postParameters.data());
}
};
mFunctions.emplace_back(std::make_pair(postFunction, numberThread));
}
}
}
// Leftover output columns (h not divisible by 2*max(hP, pack)).
if (remainH > 0) {
auto lastH = hSub * 2;
MatrixInfo CLast = CT;
CLast.offsetBytes = CT.offsetBytes + CT.lineStrideBytes * (lastH / core->pack);
MatrixInfo BLast = BT;
BLast.offsetBytes = BT.offsetBytes + BT.lineStrideBytes * (lastH / hP);
MatrixInfo Bias = COT;
if (Bias.stackIndex >= 0) {
Bias.offsetBytes = COT.offsetBytes + core->bytes * lastH;
}
auto code = _generateBasicMatMul(eSub * 2, l, remainH, AT, BLast, CLast, Bias, postParameters);
if (NO_ERROR != code) {
return code;
}
}
// Leftover output rows (e not divisible by 2*eP); covers the full h.
if (remainE > 0) {
MatrixInfo CLast = CT;
CLast.offsetBytes = CT.offsetBytes + eSub * 2 * core->pack * core->bytes;
MatrixInfo ALast = AT;
ALast.offsetBytes = AT.offsetBytes + eSub * 2 * core->pack * core->bytes;
auto code = _generateBasicMatMul(remainE, l, h, ALast, BT, CLast, COT, postParameters);
if (NO_ERROR != code) {
return code;
}
}
return NO_ERROR;
}
// Discard all encoded tasks and buffer references so this computor can be
// re-encoded for a new matmul.
void StrassenMatrixComputor::onReset() {
    mFunctions.clear();
    mStack.clear();
}
// Tensor-based entry point: derives sizes/strides from the input tensors and
// forwards to the chunk-based onEncode. inputs = {A, B[, bias]}; inputL/inputH
// (when non-zero) override the reduction length and output channel count.
ErrorCode StrassenMatrixComputor::onEncode(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, const std::vector<float>& postParameters, int inputL, int inputH) {
    MNN_ASSERT(inputs.size() == 2 || inputs.size() == 3);
    MNN_ASSERT(outputs.size() == 1);
    auto core = static_cast<CPUBackend*>(backend())->functions();
    mWeightBytes = core->bytes;
    auto matA = inputs[0];
    auto matB = inputs[1];
    auto matC = outputs[0];
    // Reduction length: from B unless the caller supplied an override.
    int l = (inputL != 0) ? inputL : matB->length(1);
    // Output channels: clamp to what both B and C can hold, unless overridden.
    int h = (inputH != 0) ? inputH : std::min(matC->length(0) * core->pack, matB->length(0) * matB->length(2));
    int e = matA->length(1);
    int eP, lP, hP;
    core->MNNGetMatMulPackMode(&eP, &lP, &hP);
    int as = matA->stride(0);
    // B's line stride in elements: one packed (l x hP) tile per line.
    int bs = UP_DIV(l, lP) * lP * hP;
    int cs = matC->stride(0);
    bool useBias = (inputs.size() > 2);
    MemChunk bias;
    if (useBias) {
        bias = TensorUtils::getDescribeOrigin(inputs[2])->mem->chunk();
    }
    return onEncode(e, l, h, as, bs, cs, TensorUtils::getDescribeOrigin(matA)->mem->chunk(), TensorUtils::getDescribeOrigin(matB)->mem->chunk(), TensorUtils::getDescribeOrigin(matC)->mem->chunk(), useBias, bias, postParameters);
}
// Chunk-based entry point: builds the initial stack {A, B, C[, bias]} and
// MatrixInfo views, then kicks off the recursive encoding. Strides as/bs/cs
// are in elements; B's stride is scaled by mWeightBytes, A/C by core->bytes.
ErrorCode StrassenMatrixComputor::onEncode(int e, int l, int h, int as, int bs, int cs, const MemChunk AT, const MemChunk BT, MemChunk CT, bool useBias, const MemChunk Bias, const std::vector<float>& postParameters) {
    auto core = static_cast<CPUBackend*>(backend())->functions();
    mFunctions.clear();
    mStack = {AT, BT, CT};
    MatrixInfo matBias;
    matBias.stackIndex = -1;
    if (useBias) {
        // Bias occupies the next stack slot (index 3).
        matBias.stackIndex = (int)mStack.size();
        matBias.offsetBytes = 0;
        mStack.emplace_back(Bias);
    }
    MatrixInfo matA;
    matA.stackIndex = 0;
    matA.offsetBytes = 0;
    matA.lineStrideBytes = as * core->bytes;
    MatrixInfo matB;
    matB.stackIndex = 1;
    matB.offsetBytes = 0;
    matB.lineStrideBytes = bs * mWeightBytes;
    MatrixInfo matC;
    matC.stackIndex = 2;
    matC.offsetBytes = 0;
    matC.lineStrideBytes = cs * core->bytes;
    return _generateMatMul(e, l, h, matA, matB, matC, matBias, 0, postParameters);
}
// Replay the encoded task list. Non-null arguments replace the corresponding
// stack entries first, so one encoding can be executed against new buffers.
void StrassenMatrixComputor::onExecute(const uint8_t* AT, const uint8_t* BT, const uint8_t* COT, uint8_t* CT) {
    if (AT != nullptr) {
        mStack[0] = (uint8_t*)AT;
    }
    if (BT != nullptr) {
        mStack[1] = (uint8_t*)BT;
    }
    if (CT != nullptr) {
        mStack[2] = (uint8_t*)CT;
    }
    if (COT != nullptr) {
        // NOTE(review): assumes encoding used a bias so mStack has a 4th slot
        // — confirm callers never pass COT without useBias.
        mStack[3] = (uint8_t*)COT;
    }
    // All planning happened in onEncode; run each task with its thread count.
    for (auto& task : mFunctions) {
        MNN_CONCURRENCY_BEGIN(tId, task.second) {
            task.first(tId);
        }
        MNN_CONCURRENCY_END();
    }
}
} // namespace MNN