in tools/converter/source/optimizer/postconvert/AddTensorFormatConverter.cpp [319:602]
virtual bool onExecute(std::unique_ptr<MNN::NetT>& net) const override {
auto& mNet = net;
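// Caffe models already match MNN's native layout, so there is nothing to convert here.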
if (mNet->sourceType == MNN::NetSource_CAFFE) {
return true;
}
auto* ctx = Global<MNN::Express::OptimizeContext>::Get();
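// The source framework fixes the original tensor layout: ONNX and TorchScript
// graphs are NCHW, while TensorFlow / TFLite default to NHWC.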
auto originTensorType = MNN::MNN_DATA_FORMAT_NHWC;
if (mNet->sourceType == MNN::NetSource_ONNX || mNet->sourceType == MNN::NetSource_TORCH) {
originTensorType = MNN::MNN_DATA_FORMAT_NCHW;
}
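// Normalize the declared format of const blobs and Reshape ops to the source
// layout; NC4HW4 blobs keep their format.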
for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) {
auto op = iter->get();
if (OpParameter_Blob == op->main.type) {
if (op->main.AsBlob()->dataFormat != MNN_DATA_FORMAT_NC4HW4) {
op->main.AsBlob()->dataFormat = originTensorType;
}
}
if (OpParameter_Reshape == op->main.type) {
op->main.AsReshape()->dimType = originTensorType;
}
}
auto config = Global<modelConfig>::Get();
auto version = config->targetVersion;
// Compute every tensor's data format
std::vector<MNN_DATA_FORMAT> tensorFormats(net->tensorName.size());
std::vector<bool> readyMask(net->oplists.size());
std::fill(tensorFormats.begin(), tensorFormats.end(), MNN_DATA_FORMAT_UNKNOWN);
std::fill(readyMask.begin(), readyMask.end(), false);
bool hasChange = false;
bool complete = false;
// Record const tensor indexes
std::vector<int32_t> constTensorIndexs;
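// Fixed-point iteration: each pass resolves ops whose input formats are now known
// and marks them in readyMask; repeat while any pass makes progress.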
do {
complete = true;
hasChange = false;
for (int i = 0; i < readyMask.size(); ++i) {
if (readyMask[i]) {
continue;
}
auto op = net->oplists[i].get();
readyMask[i] = _computeTensorFormat(tensorFormats, constTensorIndexs, op, originTensorType, config->keepInputFormat, false);
if (readyMask[i]) {
hasChange = true;
} else {
complete = false;
}
}
} while (hasChange);
// Some op formats could not be determined; force the remaining format-compatible ops to use the original format
if (!complete) {
for (int i = 0; i < readyMask.size(); ++i) {
if (readyMask[i]) {
continue;
}
auto op = net->oplists[i].get();
readyMask[i] = _computeTensorFormat(tensorFormats, constTensorIndexs, op, originTensorType, config->keepInputFormat, true);
MNN_ASSERT(readyMask[i]);
}
}
// Insert extra converters where tensor formats disagree
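// convertMap caches tensor index -> output index of an already-inserted ConvertTensor,
// so each tensor is converted at most once no matter how many ops consume it.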
std::map<int, int> convertMap;
if (config->keepInputFormat) {
// Change Output
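// With keepInputFormat the graph inputs are left untouched; instead, every model
// output produced in NC4HW4 is renamed with a "__before_tr" suffix and a
// ConvertTensor op is appended that restores the original layout under the
// public output name.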
auto& outputs = mNet->outputName;
std::vector<std::unique_ptr<MNN::OpT>> extraOp;
for (auto& op : mNet->oplists) {
for (int idx : op->outputIndexes) {
for (int j = 0; j < outputs.size(); j++) {
if (mNet->tensorName[idx] == outputs[j]) {
auto outputFormat = tensorFormats[idx];
if (outputFormat == MNN_DATA_FORMAT_NC4HW4) {
auto newOutputName = outputs[j] + "__before_tr";
mNet->tensorName[idx] = newOutputName;
// Append a convert op
MNN::OpT* transformOp = new MNN::OpT;
MNN::TensorConvertInfoT* tc = new MNN::TensorConvertInfoT;
tc->source = outputFormat;
tc->dest = originTensorType;
transformOp->main.type = MNN::OpParameter_TensorConvertInfo;
transformOp->main.value = tc;
transformOp->name = newOutputName;
transformOp->inputIndexes.push_back(idx);
int newOutputIndex = (int)mNet->tensorName.size();
transformOp->outputIndexes.push_back(newOutputIndex);
tensorFormats.push_back(originTensorType);
mNet->tensorName.push_back(outputs[j]);
transformOp->type = MNN::OpType_ConvertTensor;
extraOp.emplace_back(transformOp);
}
}
}
}
}
for (auto&& op : extraOp) {
mNet->oplists.emplace_back(std::move(op));
}
} else {
// Change Input
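// Without keepInputFormat the Input ops adopt the computed format directly; a 4-D
// NHWC input promoted to NC4HW4 also gets its dims reordered to NCHW (first run only).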
for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end(); iter++) {
auto& op = *iter;
if (OpType_Input == op->type) {
auto originInputFormat = op->main.AsInput()->dformat;
op->main.AsInput()->dformat = tensorFormats[op->outputIndexes[0]];
if (originInputFormat == MNN_DATA_FORMAT_NHWC && op->main.AsInput()->dformat == MNN_DATA_FORMAT_NC4HW4 && op->main.AsInput()->dims.size() == 4 && ctx->first_run) {
int n = op->main.AsInput()->dims[0];
int h = op->main.AsInput()->dims[1];
int w = op->main.AsInput()->dims[2];
int c = op->main.AsInput()->dims[3];
op->main.AsInput()->dims = {n, c, h, w};
}
}
}
}
if (originTensorType == MNN_DATA_FORMAT_NHWC) {
for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) {
auto op = iter->get();
// Insert Pretreat Op if needed
if (op->type == OpType_Padding && tensorFormats[op->outputIndexes[0]] == MNN_DATA_FORMAT_NC4HW4 && ctx->first_run) {
const int padValueIndex = op->inputIndexes[1];
auto padValueOp = PostTreatUtils::_findOpByOutputIndex(padValueIndex, mNet.get());
// The Padding op's second input holds per-dimension pads in NHWC order;
// insert a GatherV2 with indices {0, 3, 1, 2} to reorder them to NCHW
std::unique_ptr<OpT> gatherIndex(new OpT);
gatherIndex->outputIndexes = {(int)mNet->tensorName.size()};
gatherIndex->type = OpType_Const;
gatherIndex->name = op->name + "_Gather_Index";
mNet->tensorName.emplace_back(gatherIndex->name);
tensorFormats.push_back(originTensorType);
gatherIndex->main.type = OpParameter_Blob;
gatherIndex->main.value = new BlobT;
gatherIndex->main.AsBlob()->dataType = DataType_DT_INT32;
gatherIndex->main.AsBlob()->dataFormat = originTensorType;
gatherIndex->main.AsBlob()->int32s = {0, 3, 1, 2};
gatherIndex->main.AsBlob()->dims = {4};
std::unique_ptr<OpT> gather(new OpT);
gather->outputIndexes = {(int)mNet->tensorName.size()};
gather->inputIndexes = {op->inputIndexes[1], gatherIndex->outputIndexes[0]};
gather->type = OpType_GatherV2;
gather->name = op->name + "_Gather";
mNet->tensorName.emplace_back(gather->name);
tensorFormats.push_back(originTensorType);
op->inputIndexes[1] = gather->outputIndexes[0];
iter = mNet->oplists.insert(iter, std::move(gather));
iter = mNet->oplists.insert(iter, std::move(gatherIndex));
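// iter now points at gatherIndex; step over both inserted ops and the Padding op itself.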
iter++;
iter++;
iter++;
} else {
iter++;
}
}
}
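// Main pass: for every op that consumes tensor content, insert a ConvertTensor op
// in front of each input whose current format differs from what the op requires.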
for (auto iter = mNet->oplists.begin(); iter != mNet->oplists.end();) {
auto& op = *iter;
if (op->inputIndexes.empty()) {
iter++;
continue;
}
if (!_OpNeedConvertContent(op->type)) {
iter++;
continue;
}
auto formatType = _getFormatType(op.get(), originTensorType);
std::vector<MNN::OpT*> transformOps;
auto currentName = op->name;
for (int i = 0; i < op->inputIndexes.size(); ++i) {
auto inputIndex = op->inputIndexes[i];
if (inputIndex < 0) {
continue; // optional input, ignore it
}
auto type = tensorFormats[inputIndex];
auto requireType = _getRequireFormat(formatType, i, tensorFormats[op->outputIndexes[0]], originTensorType);
if (type == requireType) {
continue;
}
if (convertMap.find(op->inputIndexes[i]) != convertMap.end()) {
op->inputIndexes[i] = convertMap[op->inputIndexes[i]];
continue;
}
// Insert Transform op
MNN::OpT* transformOp = new MNN::OpT;
transformOps.push_back(transformOp);
MNN::TensorConvertInfoT* tc = new MNN::TensorConvertInfoT;
tc->source = type;
tc->dest = requireType;
transformOp->main.type = MNN::OpParameter_TensorConvertInfo;
transformOp->main.value = tc;
transformOp->name = mNet->tensorName[inputIndex] + "___tr4" + op->name;
// printf("Insert convert for %s, %s 's input %d\n", net->tensorName[inputIndex].c_str(),
// op->name.c_str(), i);
transformOp->inputIndexes.push_back(inputIndex);
transformOp->outputIndexes.push_back((int)mNet->tensorName.size());
convertMap[inputIndex] = transformOp->outputIndexes[0];
tensorFormats.push_back(requireType);
mNet->tensorName.push_back(transformOp->name);
op->inputIndexes[i] = transformOp->outputIndexes[0];
transformOp->type = MNN::OpType_ConvertTensor;
}
for (int i = transformOps.size() - 1; i >= 0; i--) {
iter = mNet->oplists.insert(iter, std::unique_ptr<MNN::OpT>(transformOps[i]));
}
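// The inserted converters now sit before the current op; skip past them so the
// outer loop continues from the op after this one.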
for (; (*iter)->name != currentName; iter++) {
}
iter++;
}
if (originTensorType == MNN_DATA_FORMAT_NCHW) {
return true;
}
// For ops whose output became NC4HW4 from an NHWC source, remap axis-based parameters from NHWC to NCHW positions
const int axisMap[4] = {0, 2, 3, 1};
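// axisMap takes an NHWC axis index to its NCHW position: N->0, H->2, W->3, C->1.
// e.g. a Concat along NHWC axis 3 (channels) becomes a Concat along NCHW axis 1.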
for (auto& op : mNet->oplists) {
if (op->inputIndexes.empty()) {
continue;
}
if (tensorFormats[op->outputIndexes[0]] != MNN_DATA_FORMAT_NC4HW4) {
continue;
}
if (!ctx->first_run) {
continue;
}
if (MNN::OpType_Input == op->type) {
auto input = op->main.AsInput();
const int dimSize = input->dims.size();
if (dimSize > 2) {
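// Rotate the trailing channel dimension into slot 1, e.g. {N, H, W, C} -> {N, C, H, W}.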
const int channel = input->dims[dimSize - 1];
for (int i = dimSize - 1; i > 1; --i) {
input->dims[i] = input->dims[i - 1];
}
input->dims[1] = channel;
}
}
if (MNN::OpType_Concat == op->type) {
auto axis = op->main.AsAxis();
auto concatAxis = axis->axis;
if (concatAxis < 0) {
concatAxis = 4 + concatAxis;
}
DCHECK(concatAxis >= 0 && concatAxis <= 3) << "Concat axis ERROR!";
axis->axis = axisMap[concatAxis];
}
if (MNN::OpType_Permute == op->type) {
auto permuteT = op->main.AsPermute();
for (int i = 0; i < permuteT->dims.size(); ++i) {
DCHECK(permuteT->dims[i] >= 0 && permuteT->dims[i] <= 3) << "Dim Error ==> " << op->name;
permuteT->dims[i] = axisMap[permuteT->dims[i]];
}
}
if (MNN::OpType_Slice == op->type) {
auto slice = op->main.AsSlice();
auto sliceAxis = slice->axis;
if (sliceAxis < 0) {
sliceAxis = 4 + sliceAxis;
}
DCHECK(sliceAxis >= 0 && sliceAxis <= 3) << "Slice axis ERROR!";
slice->axis = axisMap[sliceAxis];
}
if (MNN::OpType_Reshape == op->type) {
auto reshape = op->main.AsReshape();
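// Move each target-shape entry from its NHWC slot to its NCHW slot,
// e.g. dims {1, 224, 224, 3} becomes {1, 3, 224, 224}.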
auto originDim = reshape->dims;
for (int i = 0; i < reshape->dims.size(); ++i) {
CHECK(i >= 0 && i <= 3) << "Reshape dims size must not exceed 4";
reshape->dims[axisMap[i]] = originDim[i];
}
}
if (MNN::OpType_ArgMax == op->type || MNN::OpType_ArgMin == op->type) {
auto param = op->main.AsArgMax();
auto originAxis = param->axis;
DCHECK(originAxis >= 0 && originAxis <= 3) << "ArgMax / ArgMin axis ERROR!";
param->axis = axisMap[originAxis];
}
}
return true;
}