// tools/train/source/nn/NN.cpp
//
// NN.cpp
// MNN
//
// Created by MNN on 2019/11/25.
// Copyright © 2018, Alibaba Group Holding Limited
//
#include <MNN/expr/ExecutorScope.hpp>
#include "NN.hpp"
#include "Distributions.hpp"
#include "module/ModuleInside.hpp"
#include "module/PipelineModule.hpp"
#include "module/WhileModule.hpp"
#include "module/IfModule.hpp"
#include "module/NMSModule.hpp"
#include "Initializer.hpp"
#include "MNN_generated.h"
#include "RandomGenerator.hpp"
#include "core/Macro.h"
#include "math/WingoradGenerater.hpp"
#include "core/WinogradInt8Attr.hpp"
#include <string>
using namespace MNN::Express;
namespace MNN {
namespace Express {
static VARP _activate(VARP x, NN::ActivationFunctionType type) {
switch (type) {
case NN::None:
return x;
case NN::Relu:
return _Relu(x);
case NN::Relu6:
return _Relu6(x);
default:
break;
}
return nullptr;
}
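// Inverted dropout: in training mode each element is zeroed with probability
// dropRatio and the survivors are scaled by 1 / (1 - dropRatio); in inference
// mode the input is returned unchanged.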
class DropoutModule : public Module {
public:
DropoutModule(const float dropRatio) {
mDropRatio = dropRatio;
setType("Dropout");
}
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
Express::VARP x = inputs[0];
if (getIsTraining()) {
float scale = 1.0f / (1.0f - mDropRatio);
auto mask = _Input(x->getInfo()->dim, x->getInfo()->order, x->getInfo()->type);
auto maskPtr = mask->writeMap<float>();
auto eltSize = x->getInfo()->size;
Distributions::uniform(eltSize, 0, 1, maskPtr, RandomGenerator::generator());
for (int i = 0; i < eltSize; i++) {
maskPtr[i] = maskPtr[i] < mDropRatio ? 0.0f : scale;
}
x = x * mask;
}
return {x};
}
private:
DropoutModule() = default;
Module* clone(CloneContext* ctx) const override {
DropoutModule* module(new DropoutModule);
module->mDropRatio = mDropRatio;
return this->cloneBaseTo(ctx, module);
}
float mDropRatio;
};
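// Batch normalization over the channel axis (axis 1 in NCHW). Training mode
// normalizes with the per-batch mean/variance and updates the running
// statistics with momentum; inference mode folds the running statistics,
// scale and bias into a single constant _Scale op.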
class BatchNormModule : public Module {
public:
BatchNormModule(EXPRP expr, const float m = 0.99) {
MNN_ASSERT(expr->get() != nullptr);
MNN_ASSERT(expr->get()->type() == OpType_BatchNorm);
auto bnPa = expr->get()->main_as_BatchNorm();
auto& inputs = expr->inputs();
int dims = 4;
if (!inputs.empty()) {
auto info = inputs[0]->getInfo();
if (nullptr != info) {
dims = info->dim.size();
}
}
mEps = bnPa->epsilon();
mMomentum = m;
mChannels = bnPa->channels();
std::vector<int> statShape;
int channels = mChannels;
if (dims == 2) {
statShape = {1, channels};
mReductionDims = {0};
}
if (dims == 3) {
statShape = {1, channels, 1};
mReductionDims = {0, 2};
}
if (dims == 4) {
statShape = {1, channels, 1, 1};
mReductionDims = {0, 2, 3};
}
MNN_ASSERT(bnPa->biasData()->size() == mChannels);
mBias = _TrainableParam(bnPa->biasData()->data(), statShape, NCHW);
MNN_ASSERT(bnPa->slopeData()->size() == mChannels);
mScale = _TrainableParam(bnPa->slopeData()->data(), statShape, NCHW);
MNN_ASSERT(bnPa->meanData()->size() == mChannels);
mRunningMean = _Const(bnPa->meanData()->data(), statShape, NCHW);
MNN_ASSERT(bnPa->varData()->size() == mChannels);
mRunningVariance = _Const(bnPa->varData()->data(), statShape, NCHW);
addParameter(mScale);
addParameter(mBias);
mRunningVariancePos = addParameter(mRunningVariance);
mRunningMeanPos = addParameter(mRunningMean);
setType("BatchNorm");
}
BatchNormModule(const int channels, const int dims = 4, const float m = 0.99, const float e = 1e-5) {
mMomentum = m;
mEps = e;
mChannels = channels;
std::vector<int> statShape;
if (dims == 2) {
statShape = {1, channels};
mReductionDims = {0};
}
if (dims == 3) {
statShape = {1, channels, 1};
mReductionDims = {0, 2};
}
if (dims == 4) {
statShape = {1, channels, 1, 1};
mReductionDims = {0, 2, 3};
}
mScale = _TrainableParam(1.0f, statShape, NCHW);
mBias = _TrainableParam(0.0f, statShape, NCHW);
mRunningMean = _Const(0.0f, statShape, NCHW);
mRunningVariance = _Const(0.0f, statShape, NCHW);
addParameter(mScale);
addParameter(mBias);
mRunningVariancePos = addParameter(mRunningVariance);
mRunningMeanPos = addParameter(mRunningMean);
setType("BatchNorm");
}
VARP runningMean() {
return mRunningMean;
}
VARP runningVariance() {
return mRunningVariance;
}
VARP scale() {
return mScale;
}
VARP bias() {
return mBias;
}
float eps() {
return mEps;
}
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
Express::VARP x = inputs[0];
auto dimFormat = x->getInfo()->order;
VARP outputData = nullptr;
if (getIsTraining()) {
if (dimFormat == NC4HW4 || dimFormat == NHWC) {
x = _Convert(x, NCHW);
}
MNN_ASSERT(x->getInfo()->dim[1] == mChannels);
auto sampleMean = _ReduceMean(x, mReductionDims, true); // mean for each channel in the batch
auto xSub = x - sampleMean;
auto sampleVar = _ReduceMean(_Square(xSub), mReductionDims, true); // variance for each channel in the batch
auto rSampleStd = _Reciprocal(_Sqrt(sampleVar + _Const(mEps)));
auto normalizedData = xSub * rSampleStd;
outputData = normalizedData * mScale + mBias;
mRunningMean = _Const(mMomentum) * mRunningMean + _Const(1 - mMomentum) * sampleMean;
mRunningVariance = _Const(mMomentum) * mRunningVariance + _Const(1 - mMomentum) * sampleVar;
outputData->setName(name());
outputData = _Convert(outputData, dimFormat);
setParameter(mRunningMean, mRunningMeanPos);
setParameter(mRunningVariance, mRunningVariancePos);
return {outputData};
}
auto rStd = _Const(1.0f) / _Sqrt(mRunningVariance + _Const(mEps));
auto alpha = rStd * mScale;
auto beta = mBias - mRunningMean * rStd * mScale;
//outputData = (_Convert(x, NCHW) * alpha) + beta;
alpha.fix(VARP::CONSTANT);
beta.fix(VARP::CONSTANT);
//FUNC_PRINT_ALL(alpha->readMap<float>()[0], f);
x = _Convert(x, NC4HW4);
std::vector<float> scale(alpha->getInfo()->size);
std::vector<float> bias(beta->getInfo()->size);
::memcpy(scale.data(), alpha->readMap<float>(), scale.size() * sizeof(float));
::memcpy(bias.data(), beta->readMap<float>(), bias.size() * sizeof(float));
outputData = _Scale(x, mChannels, std::move(scale), std::move(bias));
outputData->setName(name());
outputData = _Convert(outputData, dimFormat);
return {outputData};
}
private:
BatchNormModule() = default;
Module* clone(CloneContext* ctx) const override {
BatchNormModule* module(new BatchNormModule);
module->mMomentum = mMomentum;
module->mEps = mEps;
module->mScale = ctx->getOrClone(mScale);
module->mBias = ctx->getOrClone(mBias);
module->mRunningMean = ctx->getOrClone(mRunningMean);
module->mRunningVariance = ctx->getOrClone(mRunningVariance);
module->mRunningMeanPos = mRunningMeanPos;
module->mRunningVariancePos = mRunningVariancePos;
module->mChannels = mChannels;
module->mReductionDims = mReductionDims;
return this->cloneBaseTo(ctx, module);
}
float mMomentum = 0.99;
float mEps = 1e-5;
VARP mScale = nullptr;
VARP mBias = nullptr;
VARP mRunningMean = nullptr;
VARP mRunningVariance = nullptr;
int mRunningMeanPos = -1;
int mRunningVariancePos = -1;
int mChannels;
std::vector<int> mReductionDims;
};
void NN::ConvOption::reset(int size) {
stride = std::vector<int>(size, 1);
channel = std::vector<int>(size, 0);
kernelSize = std::vector<int>(size, 1);
dilate = std::vector<int>(size, 1);
padMode = VALID;
pads = std::vector<int>(size, 0);
depthwise = false;
fusedActivationFunction = None;
}
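// Plain float convolution. Training mode feeds the trainable weight/bias
// VARPs into _Conv directly so gradients can flow; inference mode copies
// their current values into a constant _Conv with the activation fused in.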
class ConvModule : public Module {
public:
ConvModule(const NN::ConvParameters& parameters) {
mParameter = parameters;
if (nullptr != mParameter.bias) {
addParameter(mParameter.bias);
}
if (nullptr != mParameter.weight) {
addParameter(mParameter.weight);
}
setName(parameters.name);
setType("Conv");
}
NN::ConvParameters& convParameters() {
return mParameter;
}
virtual std::vector<VARP> onForward(const std::vector<VARP>& inputs) override {
auto input = inputs[0];
auto& option = mParameter.option;
if (getIsTraining()) {
auto tempOutput = _Conv(mParameter.weight, mParameter.bias, _Convert(input, NC4HW4), option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads);
tempOutput->setName(name());
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
return {tempOutput};
}
bool relu = option.fusedActivationFunction == NN::Relu;
bool relu6 = option.fusedActivationFunction == NN::Relu6;
std::vector<float> weight;
std::vector<float> bias;
{
auto weightInfo = mParameter.weight->getInfo();
weight.resize(weightInfo->size);
::memcpy(weight.data(), mParameter.weight->readMap<float>(), weight.size() * sizeof(float));
}
{
bias.resize(mParameter.option.channel[1]);
if (nullptr != mParameter.bias) {
::memcpy(bias.data(), mParameter.bias->readMap<float>(), bias.size() * sizeof(float));
} else {
::memset(bias.data(), 0, bias.size() * sizeof(float));
}
}
auto tempOutput = _Conv(std::move(weight), std::move(bias), _Convert(input, NC4HW4), option.channel, option.kernelSize, option.padMode, option.stride, option.dilate, mParameter.group, mParameter.option.pads, relu, relu6);
tempOutput->setName(name());
return {tempOutput};
}
private:
ConvModule() = default;
Module* clone(CloneContext* ctx) const override {
ConvModule* module(new ConvModule);
module->mParameter = mParameter;
module->mParameter.weight = ctx->getOrClone(mParameter.weight);
module->mParameter.bias = ctx->getOrClone(mParameter.bias);
return this->cloneBaseTo(ctx, module);
}
NN::ConvParameters mParameter;
};
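// Creates trainable weight/bias variables for a convolution. Missing
// initializers default to Xavier (weight) and constant zero (bias).
// Returns {weight, bias, group}, where group equals the channel count for
// depthwise convolutions and 1 otherwise.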
static std::tuple<VARP, VARP, int> _initParameters(const NN::ConvOption& option, bool hasBias,
std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
std::tuple<VARP, VARP, int> defaultRes;
if (nullptr == weightInit) {
weightInit.reset(Initializer::xavier());
}
if (nullptr == biasInit) {
biasInit.reset(Initializer::constValue(0.0f));
}
VARP weight;
int group = 1;
if (option.depthwise) {
if (option.channel[1] != option.channel[0]) {
MNN_ERROR("Can't support not the same channel for convolution depthwise\n");
return defaultRes;
}
weight = weightInit->createConstVar({option.channel[0], 1, option.kernelSize[1], option.kernelSize[0]}, NCHW);
weight.fix(VARP::TRAINABLE);
group = option.channel[0];
} else {
weight = weightInit->createConstVar(
{option.channel[1], option.channel[0], option.kernelSize[1], option.kernelSize[0]}, NCHW);
weight.fix(VARP::TRAINABLE);
}
VARP bias;
if (hasBias) {
bias = biasInit->createConstVar({option.channel[1]}, NCHW);
bias.fix(VARP::TRAINABLE);
}
return std::make_tuple(weight, bias, group);
}
Module* NN::ConvTranspose(const ConvOption& option, bool hasBias,
std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
VARP input = _Input({1, option.channel[0], -1, -1}, NC4HW4);
auto tuple = _initParameters(option, hasBias, weightInit, biasInit);
auto weight = std::get<0>(tuple);
if (nullptr == weight) {
return nullptr;
}
if (!option.depthwise) {
weight = _Transpose(weight, {1, 0, 2, 3});
weight.fix(VARP::TRAINABLE);
}
auto bias = std::get<1>(tuple);
auto group = std::get<2>(tuple);
if (nullptr != bias) {
auto tempOutput = _Deconv(weight, bias, input, option.padMode, option.stride, option.dilate, group);
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
return NN::extract({input}, {tempOutput}, true);
}
auto tempOutput = _Deconv(weight, nullptr, input, option.padMode, option.stride, option.dilate, group);
tempOutput = _activate(tempOutput, option.fusedActivationFunction);
return NN::extract({input}, {tempOutput}, true);
}
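// Example (sketch, not part of the library): building a 3x3 convolution module
// with 3 input and 16 output channels. ConvOption::reset() fills the remaining
// fields with their defaults (stride 1, VALID padding, no dilation).
//
//   NN::ConvOption opt;
//   opt.reset(2);
//   opt.channel    = {3, 16};
//   opt.kernelSize = {3, 3};
//   std::shared_ptr<Module> conv(NN::Conv(opt, true, nullptr, nullptr));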
Module* NN::Conv(const ConvOption& option, bool hasBias, std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
auto tuple = _initParameters(option, hasBias, weightInit, biasInit);
ConvParameters parameters;
parameters.weight = std::get<0>(tuple);
if (nullptr == parameters.weight) {
return nullptr;
}
parameters.bias = std::get<1>(tuple);
parameters.group = std::get<2>(tuple);
parameters.option = option;
return new ConvModule(parameters);
}
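// Fully connected layer: y = x * W^T (+ bias), with W of shape {t, l}. Lazy
// evaluation is forced on while the expression graph is built and restored
// afterwards. Example (sketch):
//   std::shared_ptr<Module> fc(NN::Linear(128, 10, true, nullptr, nullptr));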
Module* NN::Linear(int l, int t, bool hasBias, std::shared_ptr<Initializer> weightInit,
std::shared_ptr<Initializer> biasInit) {
if (nullptr == weightInit) {
weightInit.reset(Initializer::xavier());
}
if (nullptr == biasInit) {
biasInit.reset(Initializer::constValue(0.0f));
}
auto weight = weightInit->createConstVar({t, l}, NCHW);
weight.fix(VARP::TRAINABLE);
// Save lazy mode
auto lazyEval = ExecutorScope::Current()->lazyEval;
auto lazyMode = ExecutorScope::Current()->getLazyMode();
ExecutorScope::Current()->lazyEval = true;
ExecutorScope::Current()->setLazyComputeMode(Executor::LAZY_FULL);
auto input = _Input({l}, NCHW);
auto output = _MatMul(input, weight, false, true);
if (!hasBias) {
return NN::extract({input}, {output}, true);
}
auto bias = biasInit->createConstVar({1, t}, NCHW);
bias.fix(VARP::TRAINABLE);
output = _Add(output, bias);
auto module = NN::extract({input}, {output}, true);
module->setType("Linear");
// Revert lazy mode
ExecutorScope::Current()->lazyEval = lazyEval;
ExecutorScope::Current()->setLazyComputeMode(lazyMode);
return module;
}
Module* NN::Dropout(const float dropRatio) {
return new DropoutModule(dropRatio);
}
Module* NN::BatchNorm(const int channels, const int dims, const float m, const float e) {
return new BatchNormModule(channels, dims, m, e);
}
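// Rebuilds ConvOption, weight and bias from a Convolution2D /
// ConvolutionDepthwise expression. Weight and bias come from the expression's
// extra inputs when present; otherwise they are lifted from the op's embedded
// data as trainable parameters.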
NN::ConvParameters NN::Utils::ExtractConvolution(EXPRP source) {
ConvParameters _default;
if (source->get() == nullptr) {
return _default;
}
if (source->get()->type() != OpType_Convolution && source->get()->type() != OpType_ConvolutionDepthwise) {
return _default;
}
auto conv2D = source->get()->main_as_Convolution2D();
NN::ConvOption option;
option.kernelSize = {conv2D->common()->kernelX(), conv2D->common()->kernelY()};
option.stride = {conv2D->common()->strideX(), conv2D->common()->strideY()};
if (nullptr != conv2D->common()->pads()) {
option.pads.resize(conv2D->common()->pads()->size());
for (int i=0; i<option.pads.size(); ++i) {
option.pads[i] = conv2D->common()->pads()->data()[i];
}
} else {
option.pads = {conv2D->common()->padX(), conv2D->common()->padY()};
}
switch (conv2D->common()->padMode()) {
case MNN::PadMode_SAME:
option.padMode = SAME;
break;
case MNN::PadMode_VALID:
option.padMode = VALID;
break;
case MNN::PadMode_CAFFE:
option.padMode = CAFFE;
break;
default:
break;
}
option.dilate = {conv2D->common()->dilateX(), conv2D->common()->dilateY()};
option.depthwise = source->get()->type() == OpType_ConvolutionDepthwise;
auto inputCount = conv2D->common()->inputCount();
if (0 == inputCount) {
auto inputInfo = source->inputs()[0]->getInfo();
if (nullptr != inputInfo) {
if (NHWC == inputInfo->order) {
inputCount = source->inputs()[0]->getInfo()->dim[3];
} else {
inputCount = source->inputs()[0]->getInfo()->dim[1];
}
} else {
if (nullptr == conv2D->weight()) {
MNN_ERROR("Can't extract convolution\n");
return _default;
}
auto weightCount = conv2D->weight()->size();
if (option.depthwise) {
inputCount = conv2D->common()->outputCount();
} else {
inputCount = weightCount / conv2D->common()->kernelX() / conv2D->common()->kernelY() / conv2D->common()->outputCount();
}
}
}
option.channel = {inputCount, conv2D->common()->outputCount()};
int group = 1;
if (option.depthwise) {
group = conv2D->common()->outputCount();
}
VARP weight;
auto inputs = source->inputs();
if (inputs.size() > 1) {
weight = inputs[1];
}
VARP bias;
if (inputs.size() > 2) {
bias = inputs[2];
}
if (inputs.size() < 2) {
// Extract Weight And Bias from Conv2D
if (conv2D->weight() == nullptr || conv2D->bias() == nullptr) {
return _default;
}
bias = _TrainableParam(conv2D->bias()->data(), {option.channel[1]}, NCHW);
weight = _TrainableParam(conv2D->weight()->data(), {option.channel[1], option.channel[0] / group, option.kernelSize[1], option.kernelSize[0]}, NCHW);
}
_default.option = std::move(option);
_default.weight = std::move(weight);
_default.bias = std::move(bias);
_default.group = group;
if (conv2D->common()->relu()) {
_default.option.fusedActivationFunction = NN::Relu;
}
if (conv2D->common()->relu6()) {
_default.option.fusedActivationFunction = NN::Relu6;
}
_default.name = source->name();
return _default;
}
Module* NN::Conv(const ConvParameters& parameter) {
return new ConvModule(parameter);
}
Module* NN::Utils::ExtractNotRunableOp(Express::EXPRP expr, const std::map<std::string, SubGraph>& subgraphs) {
if (nullptr == expr->get()) {
return nullptr;
}
if (expr->get()->type() == OpType_BatchNorm) {
return new BatchNormModule(expr);
}
if (expr->get()->type() == OpType_Dropout) {
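// Note: the op's own drop ratio is not extracted; a fixed ratio of 0.3 is used.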
return new DropoutModule(0.3f);
}
return nullptr;
}
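// Quantization-aware Conv (+ optional BatchNorm and ReLU/ReLU6) module.
// Training mode fake-quantizes the weights and features while tracking input
// and output min/max statistics; inference mode folds the BatchNorm into the
// weights and bias, quantizes the weights to int8 and emits a real int8
// convolution (optionally rewritten as a winograd int8 convolution).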
class ConvBNReluFusedModule : public Module {
public:
ConvBNReluFusedModule(std::vector<std::shared_ptr<Module> > modules,
NN::FeatureScaleStatMethod featureScaleStatMethod,
NN::ScaleUpdateMethod scaleUpdateMethod, const int bits, bool winograd = false) {
MNN_ASSERT(modules.size() >= 1);
MNN_ASSERT(modules[0]->type() == "Conv");
if (modules.size() == 3) {
MNN_ASSERT(modules[1]->type() == "BatchNorm");
MNN_ASSERT(modules[2]->type() == "ReLU" || modules[2]->type() == "ReLU6");
}
for (int i = 0; i < modules.size(); i++) {
auto type = modules[i]->type();
if (type == "Conv") {
mConvParameter = std::static_pointer_cast<ConvModule>(modules[i])->convParameters();
mOption = mConvParameter.option;
mGroup = mConvParameter.group;
mWeight = mConvParameter.weight;
mBias = mConvParameter.bias;
if (nullptr != mWeight) {
addParameter(mWeight);
}
if (nullptr != mBias) {
addParameter(mBias);
}
if (winograd && mOption.kernelSize[0] > 1 && mOption.kernelSize[1] > 1
&& mOption.stride[0] == 1 && mOption.stride[1] == 1
&& mOption.dilate[0] == 1 && mOption.dilate[1] == 1 && mGroup == 1) {
mWinogradAttr.reset(new WinogradInt8Attr);
mWinogradTransInputMaxPos = addParameter(mWinogradTransInputMax);
mWinogradTransInputMinPos = addParameter(mWinogradTransInputMin);
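// mWinogradTransInputMax is registered again here only to reserve a parameter
// slot; the slot is overwritten with the real transformed-weight scale via
// setParameter() in _winogradConv().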
mWinogradTransWeightScalePos = addParameter(mWinogradTransInputMax);
}
setName(mConvParameter.name);
modules[i] = nullptr;
} else if (type == "BatchNorm") {
mBatchNorm = modules[i];
registerModel({mBatchNorm});
} else if (type == "ReLU") {
mActivation = NN::Relu;
modules[i] = nullptr;
} else if (type == "ReLU6") {
mActivation = NN::Relu6;
modules[i] = nullptr;
} else {
MNN_ASSERT(false);
}
}
if (mOption.fusedActivationFunction == NN::Relu || mOption.fusedActivationFunction == NN::Relu6) {
mActivation = mOption.fusedActivationFunction;
}
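// Feature scales are always tracked per-tensor here; the requested
// featureScaleStatMethod is ignored.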
mFeatureScaleStatMethod = NN::PerTensor;
mScaleUpdateMethod = scaleUpdateMethod;
mBits = bits;
mLimit = (float)(1 << (bits - 1)) - 1.0f;
mLimitScale = _Scalar<float>(1.0f / mLimit);
mWeightClampValue = _Scalar<float>(mLimit);
// mInputClampValue = _Scalar<float>(mLimit);
// mOutputClampValue = _Scalar<float>(mLimit);
// lower bits only apply to weights
mInputClampValue = _Scalar<float>((float)(1 << (8 - 1)) - 1.0f);
mOutputClampValue = _Scalar<float>((float)(1 << (8 - 1)) - 1.0f);
mInputMinPos = addParameter(mInputMin);
mInputMaxPos = addParameter(mInputMax);
mOutputMinPos = addParameter(mOutputMin);
mOutputMaxPos = addParameter(mOutputMax);
setType("ConvBNReluFused");
}
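// Maps a (min, max) range onto the symmetric integer range [-clamp, clamp]:
// after forcing min <= 0 <= max, scale = (max - min) / (2 * clamp) and
// zeroPoint = round(-min / scale - clamp).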
std::pair<VARP, VARP> computeScaleAndZeroPoint(VARP min, VARP max, VARP clampVar) {
MNN_ASSERT((!(min == nullptr)));
MNN_ASSERT((!(max == nullptr)));
min = _Minimum(_Scalar<float>(0.0f), min);
max = _Maximum(_Scalar<float>(0.0f), max);
auto scale = (max - min) / (_Scalar(2.0f) * clampVar);
auto zeroPoint = _Round((_Scalar(0.0f) - min) / scale - clampVar);
return std::make_pair(scale, zeroPoint);
}
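// Fake-quantizes a feature tensor: derives scale/zero point either from the
// provided (useMin, useMax) or from the tensor's own min/max, rounds and
// clamps to the integer range, then dequantizes. The _Cast/_ZeroGrad pair
// below acts as a straight-through estimator so gradients bypass the
// rounding. Returns {fake-quantized x, nudged min, nudged max}.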
std::vector<VARP> fakeQuantFeatureWithMinMax(VARP x, VARP useMin, VARP useMax, VARP clampVar, INTS axis = {}) {
auto originFormat = x->getInfo()->order;
auto tempX = x;
if (originFormat == NC4HW4) {
tempX = _Convert(tempX, NCHW);
}
auto originX = tempX;
VARP min, max;
bool keepDims = false;
if (axis.size() > 0) {
// PerChannel for winograd
keepDims = true;
}
min = _ReduceMin(tempX, axis, keepDims);
max = _ReduceMax(tempX, axis, keepDims);
VARP scale, zeroPoint;
VARP nudgeMin, nudgeMax;
if (!(useMin == nullptr)) {
MNN_ASSERT(!(useMax == nullptr));
auto scaleAndZeroPoint = computeScaleAndZeroPoint(useMin, useMax, clampVar);
scale = scaleAndZeroPoint.first;
zeroPoint = scaleAndZeroPoint.second;
} else {
auto scaleAndZeroPoint = computeScaleAndZeroPoint(min, max, clampVar);
scale = scaleAndZeroPoint.first;
zeroPoint = scaleAndZeroPoint.second;
}
float limit = clampVar->readMap<float>()[0];
nudgeMin = (_Scalar<float>(-limit) - zeroPoint) * scale;
nudgeMax = (_Scalar<float>(limit) - zeroPoint) * scale;
nudgeMin = _Minimum(_Scalar<float>(0.0f), nudgeMin);
nudgeMax = _Maximum(_Scalar<float>(0.0f), nudgeMax);
auto quantX = clamp(_Round(tempX / scale + zeroPoint), clampVar);
tempX = scale * (quantX - zeroPoint);
// Break the gradient by using a cast
tempX = _Cast<float>(tempX);
// Move grad from tempX to originX
tempX = _Convert(tempX + _ZeroGrad(originX), originFormat);
return {tempX, nudgeMin, nudgeMax};
}
VARP clamp(VARP x, VARP clampVar) {
return _Maximum(_Minimum(x, clampVar), _Negative(clampVar));
}
VARP updateParameter(VARP originValue, VARP newValue) const {
if (nullptr == originValue) {
return newValue;
}
auto ptr = originValue->readMap<float>();
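// -100 is the "not yet initialized" sentinel used for the winograd
// transformed-input min/max parameters (see their declarations below).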
if (ptr[0] == -100.0f) {
return newValue;
}
switch (mScaleUpdateMethod) {
case NN::MovingAverage:
return originValue * _Scalar<float>(mMomentum) + newValue * _Scalar<float>(1.0f-mMomentum);
case NN::Maximum:
return _Maximum(originValue, newValue);
default:
break;
}
MNN_ASSERT(false);
return nullptr;
}
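// Heuristically picks winograd output tile sizes (unitH, unitW) in [2, 6],
// restricted to transform sizes alpha = unit + kernel - 1 in {4, 6}, and
// accepts winograd only if the estimated multiply count drops by more than
// 2x versus direct convolution. Returns false when the input shape is
// unknown or no profitable tile exists.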
bool bestWinogradUnit(const VARP x, int* unitH = nullptr, int* unitW = nullptr) {
if (x->getInfo() == nullptr) {
return false;
}
int kernelW = mOption.kernelSize[0], kernelH = mOption.kernelSize[1], padW = mOption.pads[0], padH = mOption.pads[1];
int outH = x->getInfo()->dim[2] + 2 * padH - kernelH + 1, outW = x->getInfo()->dim[3] + 2 * padW - kernelW + 1;
int inChannel = mOption.channel[0], outChannel = mOption.channel[1];
int threadNumber = 1, ePack = 12;
int unit2 = UP_DIV(outH * outW, ePack * threadNumber);
int maxUnit = (int)::sqrtf((float)unit2);
const int MAX_UNIT = 6, MIN_UNIT = 2;
maxUnit = std::max(std::min(maxUnit, MAX_UNIT), MIN_UNIT);
auto units = std::pair<int, int>({0, 0});
float maxRate = 2.0f, originCost = outH * outW * inChannel * outChannel * kernelH * kernelW;
std::set<int> supportSu{4, 6};
for (int uh = MIN_UNIT; uh <= maxUnit; ++uh) {
for (int uw = MIN_UNIT; uw <= maxUnit; ++uw) {
auto alphaH = uh + kernelH - 1, alphaW = uw + kernelW - 1;
if (supportSu.find(alphaH) == supportSu.end() || supportSu.find(alphaW) == supportSu.end()) {
continue;
}
float winogradCost =
(2 * alphaH * alphaW * inChannel + alphaH * alphaW * inChannel * outChannel + (alphaH * alphaW + uh * alphaW) * outChannel) * (UP_DIV(outW, uw) * UP_DIV(outH, uh));
float reduceRate = originCost / winogradCost;
if (reduceRate > maxRate) {
maxRate = reduceRate;
units = std::pair<int, int>({uh, uw});
}
}
}
if (units.first == 0 || units.second == 0) {
return false;
}
if (unitH != nullptr && unitW != nullptr) {
*unitH = units.first;
*unitW = units.second;
}
return true;
}
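// Simulated winograd convolution used during quantization-aware training:
// pads and tiles the input, applies the winograd input and weight transforms
// (which must match ConvInt8Winograd.cpp), fake-quantizes the transformed
// input while tracking its min/max, multiplies in the transformed domain,
// then applies the output transform, crops to the real output size and adds
// the bias.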
VARP _winogradConv(const VARP x, const VARP weight) {
auto inDims = x->getInfo()->dim;
int batch = inDims[0], inH = inDims[2], inW = inDims[3];
int inChannel = mOption.channel[0], outChannel = mOption.channel[1];
int kernelW = mOption.kernelSize[0], kernelH = mOption.kernelSize[1], padW = mOption.pads[0], padH = mOption.pads[1];
int outH = inH + 2 * padH - kernelH + 1, outW = inW + 2 * padW - kernelW + 1;
int unitH, unitW;
bestWinogradUnit(x, &unitH, &unitW);
if (mWinogradAttr->attrs.empty()) {
mWinogradAttr->add(0, 0, kernelH, kernelW, unitH, unitW);
}
if (unitH != mWinogradAttr->attrs[0].unitY || unitW != mWinogradAttr->attrs[0].unitX) {
MNN_ERROR("Winograd Conv not support variable input shape\n");
return nullptr;
}
int alphaH = unitH + kernelH - 1, alphaW = unitW + kernelW - 1;
int unitNumH = UP_DIV(outH, unitH), unitNumW = UP_DIV(outW, unitW);
int needH = unitNumH * unitH + kernelH - 1, needW = unitNumW * unitW + kernelW - 1;
int paddings[] = {0, 0, 0, 0, padH, needH - inH - padH, padW, needW - inW - padW};
auto xx = _Pad(x, _Const(paddings, {8}, NCHW, halide_type_of<int32_t>()));
// [ic * alphaH * alphaW, N * h_unit_num * w_unit_num]
xx = _Im2Col(xx, {alphaW, alphaH}, {1, 1}, {0, 0}, {unitW, unitH});
// [N * h_unit_num * w_unit_num, ic, alphaH, alphaW]
xx = _Transpose(_Reshape(xx, {inChannel, alphaH, alphaW, -1}), {3, 0, 1, 2});
// Must be the same as ConvInt8Winograd.cpp
Math::WinogradGenerater genH(unitH, kernelH, 1, true), genW(unitW, kernelW, 1, true);
auto srcTransH = _Const(genH.B()->host<void>(), {alphaH, alphaH}, NCHW);
auto srcTransW = _Const(genW.B()->host<void>(), {alphaW, alphaW}, NCHW);
xx = _MatMul(_MatMul(_Transpose(srcTransH, {1, 0}), xx), srcTransW);
// [alphaH * alphaW, ic, N * h_unit_num * w_unit_num]
xx = _Reshape(_Transpose(xx, {2, 3, 1, 0}), {alphaH * alphaW, inChannel, -1});
auto inputPair = fakeQuantFeatureWithMinMax(xx, nullptr, nullptr, mInputClampValue, {1, 2, 3});
mWinogradTransInputMin = updateParameter(mWinogradTransInputMin, inputPair[1]);
mWinogradTransInputMax = updateParameter(mWinogradTransInputMax, inputPair[2]);
setParameter(mWinogradTransInputMin, mWinogradTransInputMinPos);
setParameter(mWinogradTransInputMax, mWinogradTransInputMaxPos);
auto wTransH = _Const(genH.G()->host<void>(), {alphaH, kernelH}, NCHW);
auto wTransW = _Const(genW.G()->host<void>(), {alphaW, kernelW}, NCHW);
// [oc, ic, alphaH, alphaW]
auto ww = _MatMul(_MatMul(wTransH, weight), _Transpose(wTransW, {1, 0}));
// [alphaH * alphaW, oc, ic]
ww = _Transpose(_Reshape(ww, {outChannel, inChannel, -1}), {2, 0, 1});
auto wwInfo = ww->getInfo();
// simulate weight quant
auto weightScale = _Maximum(_ReduceMax(_Abs(ww), {2}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
// ww = clamp(_Round(ww * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
setParameter(weightScale, mWinogradTransWeightScalePos);
// [alphaH * alphaW, oc, N * h_unit_num * w_unit_num]
auto yy = _MatMul(ww, xx);
// [oc, N * h_unit_num * w_unit_num, alphaH, alphaW]
yy = _Reshape(_Transpose(yy, {1, 2, 0}), {outChannel, -1, alphaH, alphaW});
auto dstTransH = _Const(genH.A()->host<void>(), {alphaH, unitH}, NCHW);
auto dstTransW = _Const(genW.A()->host<void>(), {alphaW, unitW}, NCHW);
// [oc, N * h_unit_num * w_unit_num, unitH, unitW]
yy = _MatMul(_MatMul(_Transpose(dstTransH, {1, 0}), yy), dstTransW);
// [N, oc, h_unit_num * unitH, w_unit_num * unitW]
yy = _Reshape(_Transpose(_Reshape(yy, {outChannel, batch, unitNumH, unitNumW, unitH, unitW}), {1, 0, 2, 4, 3, 5}), {batch, outChannel, unitNumH * unitH, unitNumW * unitW});
int sliceStartData[] = {0, 0, 0, 0}, sliceEndData[] = {-1, -1, outH, outW};
yy = _Slice(yy, _Const(sliceStartData, {4}, NCHW), _Const(sliceEndData, {4}, NCHW));
// TODO: add operator!= to VARP
if (!(mBias == nullptr)) {
yy = yy + _Reshape(mBias, {1, -1, 1, 1});
}
return yy;
}
virtual std::vector<Express::VARP> onForward(const std::vector<Express::VARP>& inputs) override {
VARP res;
if (getIsTraining()) {
auto x = _Convert(inputs[0], NCHW);
// simulate weight quant
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
auto weightTemp = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
weightTemp = weightTemp + _ZeroGrad(mWeight);
// simulate input quant to get original input scale
auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
mInputMin = updateParameter(mInputMin, inputPair[1]);
mInputMax = updateParameter(mInputMax, inputPair[2]);
setParameter(mInputMin, mInputMinPos);
setParameter(mInputMax, mInputMaxPos);
// simulate output quant to get original output scale
if (mWinogradAttr != nullptr && bestWinogradUnit(x)) {
res = _winogradConv(x, weightTemp);
#ifdef MNN_WINOGRAD_DEBUG
VARP res2 = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
mOption.dilate, mGroup, mOption.pads);
auto diff = res2 - res;
diff = diff * diff;
FUNC_PRINT_ALL(_ReduceMax(diff)->readMap<float>()[0], f);
#endif
} else {
res = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
mOption.dilate, mGroup, mOption.pads);
}
res->setName(name());
if (mBatchNorm) {
res = mBatchNorm->forward(res);
}
res = _activate(res, mActivation);
auto outputPair = fakeQuantFeatureWithMinMax(res, nullptr, nullptr, mOutputClampValue);
mOutputMin = updateParameter(mOutputMin, outputPair[1]);
mOutputMax = updateParameter(mOutputMax, outputPair[2]);
setParameter(mOutputMin, mOutputMinPos);
setParameter(mOutputMax, mOutputMaxPos);
res = outputPair[0];
} else {
if (nullptr == mInputMin) {
// Initialize the quantization ranges for test when they have not been computed during training
// simulate weight quant
auto weightScale = _Maximum(_ReduceMax(_Abs(mWeight), {1, 2, 3}, true), _Scalar<float>(1E-6)) * _Reciprocal(mWeightClampValue);
auto weightTemp = clamp(_Round(mWeight * _Reciprocal(weightScale)), mWeightClampValue) * weightScale;
auto x = _Convert(inputs[0], NCHW);
auto inputPair = fakeQuantFeatureWithMinMax(x, nullptr, nullptr, mInputClampValue);
mInputMin = updateParameter(mInputMin, inputPair[1]);
mInputMax = updateParameter(mInputMax, inputPair[2]);
setParameter(mInputMin, mInputMinPos);
setParameter(mInputMax, mInputMaxPos);
VARP simuRes;
if (mWinogradAttr != nullptr && bestWinogradUnit(x)) {
simuRes = _winogradConv(x, weightTemp);
} else {
simuRes = _Conv(weightTemp, mBias, _Convert(inputPair[0], NC4HW4), mOption.padMode, mOption.stride,
mOption.dilate, mGroup, mOption.pads);
}
if (mBatchNorm) {
simuRes = mBatchNorm->forward(simuRes);
}
simuRes = _activate(simuRes, mActivation);
Variable::prepareCompute({simuRes});
auto outputPair = fakeQuantFeatureWithMinMax(simuRes, nullptr, nullptr, mOutputClampValue);
mOutputMin = updateParameter(mOutputMin, outputPair[1]);
mOutputMax = updateParameter(mOutputMax, outputPair[2]);
setParameter(mOutputMin, mOutputMinPos);
setParameter(mOutputMax, mOutputMaxPos);
}
// fold bn to conv weights and bias
VARP fusedWeights = mWeight;
VARP fusedBias = mBias;
fusedBias = _Reshape(fusedBias, {static_cast<int>(fusedBias->getInfo()->size), 1, 1, 1});
if (mBatchNorm) {
auto bn = std::static_pointer_cast<BatchNormModule>(mBatchNorm);
auto bnMean = bn->runningMean();
auto bnVar = bn->runningVariance();
auto bnScale = bn->scale();
auto bnBias = bn->bias();
auto bnEps = bn->eps();
MNN_ASSERT(bnMean->getInfo()->dim.size() == 4);
auto rStd = _Const(1.0f) / _Sqrt(bnVar + _Const(bnEps));
auto alpha = rStd * bnScale;
auto beta = bnBias - bnMean * rStd * bnScale;
alpha = _Reshape(alpha, {static_cast<int>(alpha->getInfo()->size), 1, 1, 1});
beta = _Reshape(beta, {static_cast<int>(beta->getInfo()->size), 1, 1, 1});
fusedWeights = alpha * fusedWeights;
fusedBias = alpha * fusedBias + beta;
}
auto x = _Convert(inputs[0], NC4HW4);
int8_t inputZeroPoint, outputZeroPoint;
{
VARP channelScale, zeroPoint;
auto scaleAndZeroPoint = computeScaleAndZeroPoint(mInputMin, mInputMax, mInputClampValue);
mInputScale = scaleAndZeroPoint.first;
mInputZeroPoint = scaleAndZeroPoint.second;
// always PerTensor
channelScale = _Reciprocal(mInputScale);
zeroPoint = _Cast<int8_t>(mInputZeroPoint);
inputZeroPoint = zeroPoint->readMap<int8_t>()[0];
x = _FloatToInt8(x, channelScale, -int8_t(mInputClampValue->readMap<float>()[0]), int8_t(mInputClampValue->readMap<float>()[0]), inputZeroPoint);
}
{
VARP channelScale, zeroPoint;
auto scaleAndZeroPoint = computeScaleAndZeroPoint(mOutputMin, mOutputMax, mOutputClampValue);
mOutputScale = scaleAndZeroPoint.first;
mOutputZeroPoint = scaleAndZeroPoint.second;
// always PerTensor
channelScale = mOutputScale;
zeroPoint = _Cast<int8_t>(mOutputZeroPoint);
outputZeroPoint = zeroPoint->readMap<int8_t>()[0];
}
std::vector<int8_t> weight;
std::vector<float> bias;
std::vector<float> weightScaleVector;
{
VARP weightScale, quanWeight, convScale;
// auto newWeight = fusedWeights * mInputScale;
weightScale = _Maximum(_ReduceMax(_Abs(fusedWeights), {1, 2, 3}, true), _Scalar<float>(1E-6)) * mLimitScale;
quanWeight = _Cast<int8_t>(clamp(_Round(fusedWeights * _Reciprocal(weightScale)), mWeightClampValue));
convScale = _Reciprocal(mOutputScale) * weightScale * mInputScale;
Variable::prepareCompute({quanWeight, convScale});
// // reference for how to get quantized bias
// auto remains = _ReduceSum(_Cast<int32_t>(mInputZeroPoint) * _Cast<int32_t>(quanWeight), {1, 2, 3}, true);
// MNN_ASSERT((mOutputZeroPoint->getInfo()->dim.size() == 0) && (mOutputZeroPoint->getInfo()->size == 1)); // only support per-tensor, per-channel is removed.
// auto outputZeroPointFused = _Cast<int32_t>(_Cast<float>(mOutputZeroPoint) * _Reciprocal(convScale));
// auto quanBias = _Cast<int32_t>(fusedBias * _Reciprocal(weightScale * mInputScale)) - remains + outputZeroPointFused;
{
auto info = quanWeight->getInfo();
weight.resize(info->size);
auto ptr = quanWeight->readMap<int8_t>();
::memcpy(weight.data(), ptr, weight.size() * sizeof(int8_t));
}
{
auto biasinfo = fusedBias->getInfo();
bias.resize(biasinfo->size);
auto ptr = fusedBias->readMap<float>();
::memcpy(bias.data(), ptr, bias.size() * sizeof(float));
auto info = weightScale->getInfo();
weightScaleVector.resize(info->size);
MNN_ASSERT(weightScaleVector.size() == bias.size());
auto ptrScale = weightScale->readMap<float>();
::memcpy(weightScaleVector.data(), ptrScale, weightScaleVector.size() * sizeof(float));
}
}
bool relu = (mActivation != NN::None);
res = _Conv(std::move(weight), std::move(bias), std::move(weightScaleVector), _Convert(x, NC4HW4), mOption.channel,
mOption.kernelSize, mOption.padMode, mOption.stride, mOption.dilate, mGroup, mOption.pads, relu,
mInputScale->readMap<float>()[0], mOutputScale->readMap<float>()[0],
inputZeroPoint, outputZeroPoint,
-int8_t(mOutputClampValue->readMap<float>()[0]), int8_t(mOutputClampValue->readMap<float>()[0]), mWeightClampValue->readMap<float>()[0], mAccumulateToInt16);
if (mWinogradAttr != nullptr && !mWinogradAttr->attrs.empty()) {
auto scaleAndZeroPoint = computeScaleAndZeroPoint(mWinogradTransInputMin, mWinogradTransInputMax, mInputClampValue);
auto inputScaleVar = scaleAndZeroPoint.first;
auto inputZeroPointVar = scaleAndZeroPoint.second;
auto weightScaleVar = parameters()[mWinogradTransWeightScalePos];
// Winograd Transformed input scale
auto inputScaleInfo = inputScaleVar->getInfo();
auto inputScaleData = inputScaleVar->readMap<float>();
if (inputScaleInfo == nullptr || inputScaleData == nullptr) {
MNN_ERROR("Error for WinogradConvModule, trans input scale not ready\n");
return {};
}
std::vector<float> inputScales(inputScaleData, inputScaleData + inputScaleInfo->size);
// Winograd Transformed input zero point
inputZeroPointVar = _Cast<int32_t>(inputZeroPointVar);
auto inputZeroPointInfo = inputZeroPointVar->getInfo();
auto inputZeroPointData = inputZeroPointVar->readMap<int32_t>();
if (inputZeroPointInfo == nullptr || inputZeroPointData == nullptr) {
MNN_ERROR("Error for WinogradConvModule, trans input zero point not ready\n");
return {};
}
std::vector<int32_t> inputZeroPoints(inputZeroPointData, inputZeroPointData + inputZeroPointInfo->size);
// Winograd Transformed weight scale
auto weightScaleInfo = weightScaleVar->getInfo();
auto weightScaleData = weightScaleVar->readMap<float>();
if (weightScaleInfo == nullptr || weightScaleData == nullptr) {
MNN_ERROR("Error for WinogradConvModule, trans input scale not ready\n");
return {};
}
std::vector<float> weightScales(weightScaleData, weightScaleData + weightScaleInfo->size);
mWinogradAttr->attrs[0].inputScales = inputScales;
mWinogradAttr->attrs[0].inputZeroPoints = inputZeroPoints;
mWinogradAttr->attrs[0].weightScales = weightScales;
res = mWinogradAttr->turnToWinogradConv(res);
}
res->setName(name());
// always PerTensor
res = _Int8ToFloat(res, mOutputScale, outputZeroPoint);
}
return {res};
}
private:
ConvBNReluFusedModule() = default;
Module* clone(CloneContext* ctx) const override {
ConvBNReluFusedModule* module(new ConvBNReluFusedModule);
module->mConvParameter = mConvParameter;
module->mConvParameter.weight = ctx->getOrClone(mConvParameter.weight);
module->mConvParameter.bias = ctx->getOrClone(mConvParameter.bias);
module->mOption = mOption;
module->mGroup = mGroup;
module->mWeight = ctx->getOrClone(mWeight);
module->mBias = ctx->getOrClone(mBias);
module->mActivation = mActivation;
module->mBits = mBits;
module->mLimit = mLimit;
module->mLimitScale = ctx->getOrClone(mLimitScale);
module->mWeightClampValue = ctx->getOrClone(mWeightClampValue);
module->mInputScale = ctx->getOrClone(mInputScale);
module->mOutputScale = ctx->getOrClone(mOutputScale);
module->mInputMin = ctx->getOrClone(mInputMin);
module->mInputMax = ctx->getOrClone(mInputMax);
module->mOutputMin = ctx->getOrClone(mOutputMin);
module->mOutputMax = ctx->getOrClone(mOutputMax);
module->mInputZeroPoint = ctx->getOrClone(mInputZeroPoint);
module->mOutputZeroPoint = ctx->getOrClone(mOutputZeroPoint);
module->mInputMinPos = mInputMinPos;
module->mInputMaxPos = mInputMaxPos;
module->mOutputMinPos = mOutputMinPos;
module->mOutputMaxPos = mOutputMaxPos;
module->mInputClampValue = ctx->getOrClone(mInputClampValue);
module->mOutputClampValue = ctx->getOrClone(mOutputClampValue);
module->mMomentum = mMomentum;
module->mFeatureScaleStatMethod = mFeatureScaleStatMethod;
module->mScaleUpdateMethod = mScaleUpdateMethod;
if (mBatchNorm) {
module->mBatchNorm.reset(mBatchNorm->clone(ctx));
module->registerModel({module->mBatchNorm});
}
module->mWinogradAttr = mWinogradAttr;
module->mWinogradTransInputMin = ctx->getOrClone(mWinogradTransInputMin);
module->mWinogradTransInputMax = ctx->getOrClone(mWinogradTransInputMax);
module->mWinogradTransInputMinPos = mWinogradTransInputMinPos;
module->mWinogradTransInputMaxPos = mWinogradTransInputMaxPos;
module->mWinogradTransWeightScalePos = mWinogradTransWeightScalePos;
return this->cloneBaseTo(ctx, module);
}
NN::ConvParameters mConvParameter;
NN::ConvOption mOption;
int mGroup;
VARP mWeight;
VARP mBias;
NN::ActivationFunctionType mActivation = NN::ActivationFunctionType::None;
std::shared_ptr<Module> mBatchNorm = nullptr;
int mBits;
float mLimit;
VARP mLimitScale;
Express::VARP mWeightClampValue;
VARP mInputScale = nullptr;
VARP mOutputScale = nullptr;
VARP mInputMin = nullptr;
VARP mInputMax = nullptr;
VARP mOutputMin = nullptr;
VARP mOutputMax = nullptr;
VARP mInputZeroPoint = nullptr;
VARP mOutputZeroPoint = nullptr;
int mInputMinPos = -1;
int mInputMaxPos = -1;
int mOutputMinPos = -1;
int mOutputMaxPos = -1;
VARP mInputClampValue;
VARP mOutputClampValue;
float mMomentum = 0.99f;
NN::FeatureScaleStatMethod mFeatureScaleStatMethod;
NN::ScaleUpdateMethod mScaleUpdateMethod;
bool mAccumulateToInt16 = false;
std::shared_ptr<WinogradInt8Attr> mWinogradAttr;
VARP mWinogradTransInputMin = _Const(-100.f);
VARP mWinogradTransInputMax = _Const(-100.f);
int mWinogradTransInputMinPos = -1;
int mWinogradTransInputMaxPos = -1;
int mWinogradTransWeightScalePos = -1;
};
Module* NN::ConvBNReluFused(std::vector<std::shared_ptr<Module> > modules,
NN::FeatureScaleStatMethod featureScaleStatMethod,
NN::ScaleUpdateMethod scaleUpdateMethod, const int bits, bool winograd) {
return new ConvBNReluFusedModule(modules, featureScaleStatMethod, scaleUpdateMethod, bits, winograd);
}
Module* NN::ConvInt8(const ConvOption& option, int bits, bool hasBias,
std::shared_ptr<Initializer> weightInit, std::shared_ptr<Initializer> biasInit, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
std::shared_ptr<Module> conv(NN::Conv(option));
return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
Module* NN::ConvInt8(const ConvParameters& para, int bits, NN::FeatureScaleStatMethod featureMethod, NN::ScaleUpdateMethod method) {
std::shared_ptr<Module> conv(NN::Conv(para));
return new ConvBNReluFusedModule({conv}, featureMethod, method, bits);
}
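// Rewrites a PipelineModule in place for quantization-aware training: scans
// the submodule list, fuses each Conv with a directly following BatchNorm
// and/or ReLU/ReLU6 (when the connection is unique and the ReLU is not a
// PReLU) into a ConvBNReluFusedModule, then erases the submodules that were
// merged away.
// Example (sketch, assuming `net` holds a PipelineModule built with
// NN::extract(..., true)):
//   NN::turnQuantize(net.get(), 8, NN::PerTensor, NN::MovingAverage, false);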
bool NN::turnQuantize(Module* module, const int bits, NN::FeatureScaleStatMethod featureScaleStatMethod, NN::ScaleUpdateMethod scaleUpdateMethod, bool winogradOpt) {
if (nullptr == module || module->type() != PIPELINE_MODULE) {
MNN_ERROR("Invalide module for quantized\n");
return false;
}
auto pipModule = static_cast<PipelineModule*>(module);
std::vector<int> needEraseIndices;
for (int i = 0; i < pipModule->mSubModules.size(); i++) {
auto& m = pipModule->mSubModules[i];
auto& theModule = std::get<0>(m);
auto moduleType = theModule->type();
//auto& inputIndices = std::get<1>(m);
auto& outputIndices = std::get<2>(m);
if (moduleType == "Conv" && i < pipModule->mSubModules.size() - 1) {
auto& p1 = pipModule->mSubModules[i+1];
auto p1Module = std::get<0>(p1);
auto& p1ModuleType = p1Module->type();
auto& p1InputIndices = std::get<1>(p1);
auto& p1OutputIndices = std::get<2>(p1);
auto convOutputCount = pipModule->countOutputReference(outputIndices);
bool convSingleOutputReference = ((outputIndices.size() == 1) && (convOutputCount[0] == 1));
// only conv
if ((!convSingleOutputReference) || (p1ModuleType == "Conv") ||
(p1ModuleType != "BatchNorm" && p1ModuleType != "ReLU" && p1ModuleType != "ReLU6")) {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
pipModule->registerModel({theModule});
continue;
}
// conv + bn + ?
if (p1ModuleType == "BatchNorm") {
bool convBnConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
if (!convBnConnected) {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
pipModule->registerModel({theModule});
continue;
}
// last conv + bn
if (i == pipModule->mSubModules.size() - 2) {
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
}
// maybe there is a relu or relu6 after conv + bn
auto& p2 = pipModule->mSubModules[i+2];
auto& p2Module = std::get<0>(p2);
auto p2ModuleType = p2Module->type();
auto& p2InputIndices = std::get<1>(p2);
auto& p2OutputIndices = std::get<2>(p2);
auto bnOutputCount = pipModule->countOutputReference(p1OutputIndices);
bool bnSingleOutputReference = ((p1OutputIndices.size() == 1) && (bnOutputCount[0] == 1));
// only conv + bn
if ((!bnSingleOutputReference) || (p2ModuleType != "ReLU" && p2ModuleType != "ReLU6")) {
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
} else { // conv + bn + relu or conv + bn + relu6
bool convBnReluConnected = ((bnSingleOutputReference) && (p2InputIndices.size() == 1) && (p2InputIndices[0] == p1OutputIndices[0]));
bool isPrelu = false;
if (p2ModuleType == "ReLU") {
auto p2Op = ((ExprModule*)p2Module.get())->getExpr()->get();
float slope = p2Op->main_as_Relu()->slope();
isPrelu = std::abs(slope) > 1e-6;
}
if (!convBnReluConnected || isPrelu) {
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
}
theModule.reset(NN::ConvBNReluFused({theModule, p1Module, p2Module}, featureScaleStatMethod, scaleUpdateMethod, bits, winogradOpt));
pipModule->registerModel({theModule});
outputIndices = p2OutputIndices;
needEraseIndices.emplace_back(i + 1);
needEraseIndices.emplace_back(i + 2);
continue;
}
}
// conv + relu or conv + relu6
if (p1ModuleType == "ReLU" || p1ModuleType == "ReLU6") {
bool convReluConnected = ((convSingleOutputReference) && (p1InputIndices.size() == 1) && (p1InputIndices[0] == outputIndices[0]));
bool isPrelu = false;
if (p1ModuleType == "ReLU") {
auto p1Op = ((ExprModule*)p1Module.get())->getExpr()->get();
float slope = p1Op->main_as_Relu()->slope();
isPrelu = std::abs(slope) > 1e-6;
}
if (!convReluConnected || isPrelu) {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
continue;
}
theModule.reset(NN::ConvBNReluFused({theModule, p1Module}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
outputIndices = p1OutputIndices;
needEraseIndices.emplace_back(i + 1);
continue;
}
}
if (i == pipModule->mSubModules.size() - 1 && moduleType == "Conv") {
theModule.reset(NN::ConvBNReluFused({theModule}, featureScaleStatMethod, scaleUpdateMethod, bits));
pipModule->registerModel({theModule});
}
}
// erase useless submodules
const int eraseSize = needEraseIndices.size();
int alreadyErasedCount = 0;
for (int i = 0; i < eraseSize; i++) {
auto position = needEraseIndices[i] - alreadyErasedCount;
auto type = std::get<0>(pipModule->mSubModules[position])->type();
MNN_ASSERT(type == "BatchNorm" || type == "ReLU" || type == "ReLU6");
pipModule->mSubModules.erase(pipModule->mSubModules.begin() + position);
alreadyErasedCount++;
}
return true;
}
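// Builds a PipelineModule from the given input/output variables. With
// fortrain == true, Convolution/ConvolutionDepthwise, BatchNorm and Dropout
// expressions are replaced by trainable modules; otherwise only the
// non-runnable ops (BatchNorm, Dropout) are wrapped.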
Module* NN::extract(std::vector<Express::VARP> inputs, std::vector<Express::VARP> outputs, bool fortrain, const std::map<std::string, SubGraph>& subGraph) {
std::function<std::pair<std::vector<int>, std::shared_ptr<Module>>(EXPRP)> transformFunction;
if (fortrain) {
transformFunction =
[&subGraph](EXPRP source) {
if (source->get() == nullptr) {
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
}
std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
if (nullptr != m) {
m->setName(source->name());
return std::make_pair(std::vector<int>{}, m);
}
auto convExtracted = NN::Utils::ExtractConvolution(source);
if (convExtracted.weight == nullptr) {
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
}
std::shared_ptr<Module> module(NN::Conv(convExtracted));
module->setName(source->name());
return std::make_pair(std::vector<int>{0}, module);
};
} else {
transformFunction = [&subGraph](EXPRP source) {
if (source->get() == nullptr) {
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
}
std::shared_ptr<Module> m(NN::Utils::ExtractNotRunableOp(source, subGraph));
if (nullptr != m) {
m->setName(source->name());
return std::make_pair(std::vector<int>{}, m);
}
return std::make_pair(std::vector<int>{}, std::shared_ptr<Module>(nullptr));
};
}
return new PipelineModule(inputs, outputs, transformFunction);
}
} // namespace Express
} // namespace MNN