tools/converter/source/common/writeFb.cpp (289 lines of code) (raw):

// // writeFb.cpp // MNNConverter // // Created by MNN on 2019/01/31. // Copyright © 2018, Alibaba Group Holding Limited // #include <fstream> #include <iostream> #include <algorithm> #include <set> #include <string> #include <sstream> #include "MNN_generated.h" #include "core/MNNFileUtils.h" #include "logkit.h" #include "writeFb.hpp" #include "CommonUtils.hpp" #include "cpp/ConfigFile.hpp" #include <MNN/MNNDefine.h> #include "cli.hpp" #include "MNN_compression.pb.h" using namespace MNN; using namespace std; static void _postTreatOp(std::unique_ptr<OpT>& op, FileLoader* fl, const PostTreatContext& context, const modelConfig& config, std::ofstream& weightPath, int64_t& offset, bool needExternalWeight) { loadExternalParam(op, fl); if (config.alignDenormalizedValue) { AlignDenormalizedValue(op); } if (config.saveHalfFloat) { CastParamsToHalf(op); } if (config.detectSparseSpeedUp) { AddSparseInfo(op, context.proto); } WeightQuantAndCoding(op, config, &context); if (needExternalWeight) { RemoveAndStoreParam(op, &weightPath, offset); } } static float _computeOpExternalSizeInMB(const MNN::OpT* op) { switch (op->main.type) { case MNN::OpParameter_Convolution2D: { auto conv2D = op->main.AsConvolution2D(); if (conv2D->external.empty()) { return 0.0f; } return ((float)conv2D->external[1] + (float)conv2D->external[2]) / 1024.0f / 1024.0f; } case MNN::OpParameter_Blob: { auto blob = op->main.AsBlob(); if (blob->external.empty()) { return 0.0f; } return blob->external[1] / 1024.0f / 1024.0f; } default: break; } return 0.0f; } static bool _largeModel(const MNN::NetT* netT) { float summer = 0.0f; for (auto& op : netT->oplists) { summer+= _computeOpExternalSizeInMB(op.get()); if (summer > 2000.0f) { MNN_PRINT("Model larger than 2GB\n"); return true; } } for (auto& subgraph : netT->subgraphs) { for (auto& op : subgraph->nodes) { summer+= _computeOpExternalSizeInMB(op.get()); if (summer > 2000.0f) { MNN_PRINT("Model larger than 2GB\n"); return true; } } } return false; } int postTreat(std::unique_ptr<MNN::NetT>& netT, const modelConfig& config) { std::string compressFileName = config.compressionParamsFile; auto& proto = config.compressInfo->proto; auto MNNModelFile = config.MNNModel; addUUID(netT, config.compressInfo->proto); bool useOriginQuant = config.compressInfo->proto.algo_size() > 0 && (!config.compressInfo->write); if (config.compressInfo->proto.has_for_guide() && config.compressInfo->proto.for_guide()) { useOriginQuant = false; } if (useOriginQuant) { channelPruneConvert(netT, proto); } if (useOriginQuant) { fullQuantAndCoding(netT, proto); } // Check If need external weight bool needExternalWeight = config.saveExternalData; if ((!needExternalWeight) && config.model != modelConfig::MNN) { needExternalWeight = _largeModel(netT.get()); } std::ofstream externalWeightOs; if (needExternalWeight) { auto weightName = config.MNNModel + ".weight"; MNN_PRINT("Save Weight to %s\n", weightName.c_str()); externalWeightOs.open(weightName.c_str(), ios::binary); if (externalWeightOs.fail()) { MNN_PRINT("Write %s failed\n", weightName.c_str()); } } { bool findQuant = false; auto& context = *config.compressInfo; for (int i=0; i<proto.algo_size(); ++i) { auto& algo = proto.algo(i); if (algo.type() != MNN::Compression::CompressionAlgo_CompressionType_QUANTIZE) { continue; } if (!algo.has_quant_params()) { continue; } findQuant = true; for (int j=0; j<algo.quant_params().layer_size(); ++j) { const auto& quant = algo.quant_params().layer(j); std::string opName; if (!quant.has_op_name()) { continue; } opName = quant.op_name(); std::string graphName; if (quant.has_subgraph_name()) { graphName = quant.subgraph_name(); } context.quantInfo.insert(std::make_pair(std::make_pair(graphName, opName), &quant)); } break; } if ((!findQuant) && context.write) { // Add Quant param for write proto.set_version(MNN_VERSION); proto.set_for_guide(true); auto algo = proto.add_algo(); algo->set_type( MNN::Compression::CompressionAlgo_CompressionType_QUANTIZE); context.quantMutableInfo = algo->mutable_quant_params(); } int64_t offset = 0; FileLoader fl(".__convert_external_data.bin"); for (auto& op : netT->oplists) { _postTreatOp(op, &fl, context, config, externalWeightOs, offset, needExternalWeight); } for (auto& subgraph : netT->subgraphs) { context.subgraph = subgraph->name; for (auto& op : subgraph->nodes) { _postTreatOp(op, &fl, context, config, externalWeightOs, offset, needExternalWeight); } } } { MNNRemoveFile(".__convert_external_data.bin"); } if (config.compressInfo->write) { CommonKit::protobuf2json(compressFileName.c_str(), &proto); } return 0; } int writeFb(std::unique_ptr<MNN::NetT>& netT, const modelConfig& config, std::unique_ptr<MNN::OpT>&& metaOp) { postTreat(netT, config); // Merge Meta to metaOp auto oplist = std::move(netT->oplists); for (auto& op : oplist) { if (op->type == OpType_Extra) { auto dstExtra = metaOp->main.AsExtra(); auto extra = op->main.AsExtra(); if (extra->type == "Meta" && extra->engine == "MNN") { for (auto& attr : extra->attr) { dstExtra->attr.emplace_back(std::move(attr)); } // Remove meta op continue; } } netT->oplists.emplace_back(std::move(op)); } std::set<std::string> notSupportOps; // Detect unsupport op auto CheckIfNotSupported = [&] (const std::unique_ptr<MNN::OpT>& op) { if (op->type == MNN::OpType_Extra) { if (op->main.AsExtra()->engine != "MNN") { notSupportOps.insert(op->main.AsExtra()->engine + "::" + op->main.AsExtra()->type); } } }; for (auto& op : netT->oplists) { CheckIfNotSupported(op); } for (auto& subgraph : netT->subgraphs) { for (auto& op : subgraph->nodes) { CheckIfNotSupported(op); } } std::ostringstream notSupportInfo; if (!notSupportOps.empty() && !config.allowCustomOp) { for (auto name : notSupportOps) { notSupportInfo << name << " | "; } auto opNames = notSupportInfo.str(); LOG(FATAL) << "These Op Not Support: " << opNames.substr(0, opNames.size() - 2); return 1; } // dump input and output tensor name { std::set<int> inputIdx, outputIdx, realOutput; std::vector<int> realInput; for (const auto& op : netT->oplists) { for (auto i : op->inputIndexes) { inputIdx.insert(i); } for (auto o : op->outputIndexes) { outputIdx.insert(o); if (op->type == OpType_Input) { realInput.emplace_back(o); } } } std::set_difference(outputIdx.begin(), outputIdx.end(), inputIdx.begin(), inputIdx.end(), std::inserter(realOutput, realOutput.begin())); std::cout << "inputTensors : [ "; for (int i : realInput) { std::cout << netT->tensorName[i] << ", "; } std::cout << "]\noutputTensors: [ "; if (netT->outputName.size() > 0) { for (auto& o : netT->outputName) { std::cout << o << ", "; } } else { for (int i : realOutput) { std::cout << netT->tensorName[i] << ", "; } } std::cout << "]" << std::endl; } // add version info to model netT->extraInfo.reset(new ExtraInfoT); netT->extraInfo->version = MNN_VERSION; if (!config.authCode.empty()) { // add auth code to model netT->extraInfo->name = config.authCode; } if (metaOp->main.AsExtra()->attr.size() > 0) { flatbuffers::FlatBufferBuilder builder; builder.Finish(MNN::Extra::Pack(builder, metaOp->main.AsExtra())); netT->extraInfo->buffer.resize(builder.GetSize()); ::memcpy(netT->extraInfo->buffer.data(), builder.GetBufferPointer(), builder.GetSize()); } flatbuffers::FlatBufferBuilder builderOutput(1024); builderOutput.ForceDefaults(true); auto len = MNN::Net::Pack(builderOutput, netT.get()); builderOutput.Finish(len); int sizeOutput = builderOutput.GetSize(); auto bufferOutput = builderOutput.GetBufferPointer(); if (config.saveStaticModel && netT->usage != MNN::Usage_INFERENCE_STATIC) { std::map<std::string, std::vector<int>> inputConfig; // get config to set input size if (config.inputConfigFile.size() > 0) { ConfigFile conf(config.inputConfigFile); auto numOfInputs = conf.Read<int>("input_size"); auto inputNames = splitNames(numOfInputs, conf.Read<std::string>("input_names")); auto inputDims = splitDims(numOfInputs, conf.Read<std::string>("input_dims")); for (int i = 0; i < numOfInputs; i++) { inputConfig.insert(std::make_pair(inputNames[i], inputDims[i])); } } const Net* net = flatbuffers::GetRoot<MNN::Net>(bufferOutput); converToStaticModel(net, inputConfig, config.MNNModel); } else { std::ofstream output(config.MNNModel, std::ofstream::binary); output.write((const char*)bufferOutput, sizeOutput); } if (!netT->subgraphs.empty()) { MNN_PRINT("The model has subgraphs, please use MNN::Express::Module to run it\n"); } #ifdef MNN_DUMP_SUBGRAPH for (int i = 0; i < netT->subgraphs.size(); ++i) { std::unique_ptr<MNN::NetT> subnet(new MNN::NetT); auto& subgraph = netT->subgraphs[i]; subnet->oplists = std::move(subgraph->nodes); subnet->tensorName = subgraph->tensors; subnet->sourceType = netT->sourceType; subnet->bizCode = netT->bizCode; flatbuffers::FlatBufferBuilder builder(1024); builder.ForceDefaults(true); auto len = MNN::Net::Pack(builder, subnet.get()); builder.Finish(len); int output_size = builder.GetSize(); auto* output_ptr = builder.GetBufferPointer(); std::string filename = MNNModelFile + "_subgraph_" + std::to_string(i) + ".mnn"; std::ofstream output(filename.c_str(), std::ofstream::binary); output.write((const char*)output_ptr, output_size); } #endif return 0; }