in tools/converter/source/optimizer/onnxextra/OnnxConvolutionMerge.cpp [137:422]
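// Rewrites an ONNX Conv / ConvTranspose (arriving as an MNN Extra op) into a
// native MNN Convolution2D-family op: it parses the ONNX attributes, folds a
// constant weight/bias into the op parameter when profitable, and handles the
// Conv1D, quantized (ConvInt8), and explicit output_shape variants.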
virtual EXPRP onExecute(EXPRP expr) const override {
if (convSpatialDim(expr) == 3) {
return _transformConv3D(expr);
}
auto inputs = expr->inputs();
const int inputSize = static_cast<int32_t>(inputs.size());
auto x = inputs[0];
if (inputSize != 3 && inputSize != 2) {
MNN_ERROR("Convolution Input ERROR!\n");
return nullptr;
}
auto weight = inputs[1];
auto weight_expr = weight->expr().first;
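// Quantized models feed weight (and sometimes the input) through an Int8ToFloat
// dequantize op; remember that so the weight can be decoded and, when possible,
// a ConvInt8 op emitted later.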
bool weightIden = false;
bool xIden = false;
if (weight_expr->get()) {
weightIden = weight_expr->get()->type() == OpType_Int8ToFloat;
}
if (inputs[0]->expr().first->get()) {
xIden = inputs[0]->expr().first->get()->type() == OpType_Int8ToFloat;
}
if (false == weightIden && nullptr == weight->getInfo()) {
MNN_ERROR("Convolution should know weight shape infromation!\n");
return nullptr;
}
INTS weightShape = weight->getInfo()->dim;
bool convertToConvint8 = false;
auto op = expr->get();
auto extraParam = op->main_as_Extra();
std::string originalOpType(extraParam->type()->c_str());
bool isDeconv = originalOpType == "ConvTranspose";
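// ONNX Conv weight layout is [co, ci/group, kH, kW]; ConvTranspose (and the
// Int8ToFloat-wrapped weight) stores [ci, co/group, kH, kW] instead, so swap
// co/ci for those cases below.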
int co = weightShape[0];
int ci = weightShape[1];
int kh = weightShape[2];
int kw = 1;
if (weightShape.size() >= 4) {
kw = weightShape[3];
}
if (isDeconv) {
co = weightShape[1];
ci = weightShape[0];
}
if (weightIden) {
co = weightShape[1];
ci = weightShape[0];
}
int group = 1;
int dilation_h = 1;
int dilation_w = 1;
int stride_h = 1;
int stride_w = 1;
PadMode modePadding = PadMode_CAFFE;
std::vector<int> outputPadding;
std::vector<int> inputPadding;
std::vector<int> outputShape;
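// Walk the ONNX attributes and pull out the convolution hyper-parameters.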
const int attrSize = extraParam->attr()->size();
for (int i = 0; i < attrSize; ++i) {
auto attr = extraParam->attr()->GetAs<Attribute>(i);
const auto& key = attr->key()->str();
if (key == "dilations") {
auto dataList = attr->list();
dilation_h = dataList->i()->data()[0];
if (dataList->i()->size() >= 2) {
dilation_w = dataList->i()->data()[1];
}
} else if (key == "group") {
group = attr->i();
} else if (key == "strides") {
auto dataList = attr->list();
stride_h = dataList->i()->data()[0];
if (dataList->i()->size() >= 2) {
stride_w = dataList->i()->data()[1];
}
} else if (key == "auto_pad") {
if (attr->s()->str() == "NOTSET") {
modePadding = PadMode_CAFFE;
} else if (attr->s()->str() == "SAME_UPPER" || attr->s()->str() == "SAME_LOWER") {
modePadding = PadMode_SAME;
} else if (attr->s()->str() == "VALID") {
modePadding = PadMode_VALID;
} else {
MNN_ERROR("Conv auto_pad not support %s\n", attr->s()->c_str());
return nullptr;
}
} else if (key == "pads") {
auto dataList = attr->list();
inputPadding.resize(dataList->i()->size());
for (size_t v = 0; v < inputPadding.size(); v++) {
inputPadding[v] = dataList->i()->data()[v];
}
// Conv1D: ONNX gives two pads [begin, end]; expand to the 2-D layout
// {padBegin, 0, padEnd, 0} with zero padding on the dummy width axis
if (inputPadding.size() == 2) {
inputPadding = {inputPadding[0], 0, inputPadding[1], 0};
}
} else if (key == "output_padding") {
// only valid in ConvTranspose
auto dataList = attr->list();
const int size = dataList->i()->size();
for (int k = 0; k < size; ++k) {
outputPadding.push_back(dataList->i()->data()[k]);
}
if (outputPadding.size() == 1) {
outputPadding = {outputPadding[0], 0};
}
} else if (key == "output_shape") {
auto dataList = attr->list();
outputShape.resize(dataList->i()->size());
for (size_t v = 0; v < outputShape.size(); v++) {
outputShape[v] = dataList->i()->data()[v];
}
}
}
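// Translate the parsed attributes into MNN's Convolution2DCommon.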
std::unique_ptr<Convolution2DT> convParam(new MNN::Convolution2DT);
convParam->common.reset(new MNN::Convolution2DCommonT);
auto common = convParam->common.get();
// For compatibility with older MNN versions, also fill the legacy padX/padY fields
if (inputPadding.size() >= 4) {
common->padY = inputPadding[0];
common->padX = inputPadding[1];
}
common->padMode = modePadding;
// set param
common->relu = false;
common->group = group;
if (isDeconv) {
common->outputCount = co * group; // deconv weight holds per-group output channels, so multiply by group
common->inputCount = ci;
} else {
common->outputCount = co;
common->inputCount = ci * group; // conv weight holds per-group input channels (ci = C / group), so multiply by group
}
common->kernelX = kw;
common->kernelY = kh;
common->dilateX = dilation_w;
common->dilateY = dilation_h;
common->strideX = stride_w;
common->strideY = stride_h;
common->pads = inputPadding;
common->outPads = outputPadding;
if (!outputShape.empty()) {
common->hasOutputShape = true;
common->padMode = PadMode_SAME;
}
auto config = Global<modelConfig>::Get();
// read weight data
const float* weightDataPtr = nullptr;
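// Only fold a constant weight into the op if it has few enough consumers:
// folding a shared weight duplicates it in the model, so optimizePrefer trades
// model size (limit 1) against speed (limit 100).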
int limitNumber = 4;
if (config->optimizePrefer == 1) {
// Smallest
limitNumber = 1;
} else if (config->optimizePrefer == 2) {
// Fastest
limitNumber = 100;
}
VARP wf = weight;
if (weight->linkNumber() <= limitNumber && !convertToConvint8) {
if (!weightIden) {
weightDataPtr = weight->readMap<float>();
}
else {
auto yy = weight->expr().first->inputs()[0]; // weight shape: [ic,oc,kh,kw]
auto ss = _Const(weight->expr().first->get()->main_as_QuantizedFloatParam()->tensorScale()->data(), {co});
auto zz = _Const(weight->expr().first->get()->main_as_QuantizedFloatParam()->floatzeros()->data(), {co});
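// Dequantize: move the output-channel axis last so the {co}-shaped scale and
// zero point broadcast per channel, then permute back to [oc, ic, kH, kW].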
wf = (_Cast<float>(_Permute(yy, {0, 2, 3, 1})) - zz) * ss;
wf = _Permute(wf, {3, 0, 1, 2});
weightDataPtr = wf->readMap<float>();
}
}
if (weightDataPtr) {
if (weight->linkNumber() > 1) {
static bool gPrint = false;
if (!gPrint) {
MNN_PRINT("The Convolution use shared weight, may increase the model size\n");
gPrint = true;
}
}
// MNN_PRINT("MNNCountNNZBlock:%p\n", MNNCountNNZBlock);
const size_t weightSize = co * ci * kh * kw;
convParam->weight.resize(weightSize);
::memcpy(convParam->weight.data(), weightDataPtr, weightSize * sizeof(float));
convParam->bias.resize(common->outputCount);
if (inputSize == 3) {
// read bias data
auto bias = inputs[2];
const int biasNums = bias->getInfo()->size;
if (biasNums != common->outputCount) {
// TODO: support broadcast
MNN_ERROR("[TODO] Conv's bias does not support broadcast yet!\n");
return nullptr;
}
auto biasDataPtr = bias->readMap<float>();
if (!biasDataPtr) {
MNN_ERROR("Conv's bias input should be Constant!\n");
return nullptr;
}
::memcpy(convParam->bias.data(), biasDataPtr, common->outputCount * sizeof(float));
} else {
::memset(convParam->bias.data(), 0, common->outputCount * sizeof(float));
}
}
std::unique_ptr<OpT> newOp(new OpT);
newOp->name = expr->name();
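// Depthwise iff group > 1 and group == ci * co: a depthwise ONNX weight is
// [co, 1, kH, kW], so ci == 1 and group == co.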
if (isDeconv) {
newOp->type = OpType_Deconvolution;
if (group > 1 && group == ci * co) {
newOp->type = OpType_DeconvolutionDepthwise;
}
} else {
newOp->type = OpType_Convolution;
if (group > 1 && group == ci * co) {
newOp->type = OpType_ConvolutionDepthwise;
}
}
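// If both input and weight arrive via Int8ToFloat (with the dequantize op
// carrying all five quant inputs), switch to the quantized ConvInt8 op types.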
if (!isDeconv && true == weightIden && true == xIden && weight_expr->inputs().size() == 5) {
newOp->type = OpType_ConvInt8;
if (common->inputCount == common->outputCount && common->outputCount == common->group) {
newOp->type = OpType_DepthwiseConvInt8;
}
}
newOp->main.type = OpParameter_Convolution2D;
newOp->main.value = convParam.release();
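// Conv1D comes in as a 3-D tensor: append a dummy W axis, run it as a 2-D
// convolution, and squeeze the axis back off at the end.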
bool needSqueeze = false;
if (nullptr != x->getInfo()) {
if (x->getInfo()->dim.size() == 3) {
x = _Unsqueeze(x, {3});
needSqueeze = true;
}
}
EXPRP convolutionExpr;
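// Three wiring cases: explicit output_shape (extra shape input), folded
// weights (single input), or pass-through weight/bias inputs.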
if (!outputShape.empty()) {
// [1, outputHeight, outputWidth, 1]
outputShape.insert(outputShape.begin(), 1);
outputShape.push_back(1);
auto output_shape = _Const(outputShape.data(), {4}, NHWC, halide_type_of<int>());
if (weightDataPtr || convertToConvint8) {
// weight (and bias) were folded into the Convolution2D parameter above
convolutionExpr = Expr::create(newOp.get(), {x, output_shape});
} else {
// keep weight/bias as separate inputs: the MNN runtime requires a conv whose
// weight is not constant to carry them as inputs
if (inputs.size() > 2) {
convolutionExpr = Expr::create(newOp.get(), {x, inputs[1], inputs[2], output_shape});
} else {
convolutionExpr = Expr::create(newOp.get(), {x, inputs[1], output_shape});
}
}
} else if (weightDataPtr || convertToConvint8) {
// weight (and bias) were folded into the Convolution2D parameter above
convolutionExpr = Expr::create(newOp.get(), {x});
} else {
// keep weight/bias as separate inputs: the MNN runtime requires a conv whose
// weight is not constant to carry them as inputs
if (inputs.size() > 2) {
convolutionExpr = Expr::create(newOp.get(), {x, inputs[1], inputs[2]});
} else {
convolutionExpr = Expr::create(newOp.get(), {x, inputs[1]});
}
}
convolutionExpr->setName(expr->name());
auto res = Variable::create(convolutionExpr);
if (needSqueeze) {
res = _Squeeze(res, {3});
}
return res->expr().first;
}