virtual EXPRP onExecute()

in tools/converter/source/optimizer/onnxextra/OnnxEinsum.cpp [27:399]


    /**
     * Convert an ONNX Einsum node into MNN expressions.
     *
     * The "equation" attribute is normalised first: every "..." is replaced by a single
     * '.' pseudo-label and all spaces are dropped. Only the explicit form (containing
     * "->") with one or two inputs is handled:
     *   - one input : ReduceSum over the labels missing from the output, then Permute;
     *   - two inputs: a direct MatMul for common 4-D / 3-D patterns; otherwise a
     *                 broadcast multiply when no label of the first input is summed,
     *                 or permute + reshape + batched MatMul in the remaining cases.
     *
     * @param expr  the Einsum expression being replaced.
     * @return the replacement expression (named after expr), or nullptr when the
     *         equation is missing or not supported.
     */
    virtual EXPRP onExecute(EXPRP expr) const override {
        auto inputs     = expr->inputs();
        auto op         = expr->get();
        auto extraParam = op->main_as_Extra();
        std::string equation;
        if (nullptr != extraParam->attr()) {
            const int attrSize = extraParam->attr()->size();
            for (int i = 0; i < attrSize; ++i) {
                auto attr       = extraParam->attr()->GetAs<Attribute>(i);
                const auto& key = attr->key()->str();
                if (key == "equation") {
                    equation = attr->s()->str();
                }
            }
        }
        if (equation.empty()) {
            MNN_ERROR("Can't convert Einsum for invalid Equation\n");
            return nullptr;
        }
        // Give the result the name of the original node and hand back its expression.
        auto finish = [&](VARP result) -> EXPRP {
            result->setName(expr->name());
            return result->expr().first;
        };
        // Turn ... to . so that the ellipsis behaves like one ordinary label.
        bool hasPrefix = false;
        for (auto dots = equation.find("..."); dots != std::string::npos; dots = equation.find("...")) {
            equation.replace(dots, 3, ".");
            hasPrefix = true;
        }
        // Remove space
        {
            std::string compact;
            compact.reserve(equation.size());
            for (char c : equation) {
                if (c != ' ') {
                    compact.push_back(c);
                }
            }
            equation.swap(compact);
        }
        const auto arrowPos = equation.find("->");
        if (arrowPos == std::string::npos) {
            MNN_ERROR("Can't convert Einsum for no support Equation:%s\n", equation.c_str());
            return nullptr;
        }
        const std::string left  = equation.substr(0, arrowPos);
        const std::string right = equation.substr(arrowPos + 2);

        if (inputs.size() == 1) {
            auto currentVar = inputs[0];
            std::map<char, int> outputPos;
            for (int i = 0; i < (int)right.size(); ++i) {
                outputPos.insert(std::make_pair(right[i], i));
            }
            // Labels absent from the output are summed away; the surviving labels are
            // numbered in the order they remain after the reduction.
            std::vector<int> reduceAxis;
            std::map<char, int> inputPosRemap;
            int keptAxis = 0;
            for (int i = 0; i < (int)left.size(); ++i) {
                if (outputPos.find(left[i]) == outputPos.end()) {
                    reduceAxis.emplace_back(i);
                    continue;
                }
                inputPosRemap.insert(std::make_pair(left[i], keptAxis));
                ++keptAxis;
            }
            if (!reduceAxis.empty()) {
                currentVar = _ReduceSum(currentVar, reduceAxis, false);
            }
            std::vector<int> permuteDims;
            for (int i = 0; i < (int)right.size(); ++i) {
                permuteDims.emplace_back(inputPosRemap[right[i]]);
            }
            return finish(_Permute(currentVar, permuteDims));
        }
        if (inputs.size() != 2) {
            MNN_ERROR("Can't convert Einsum for input size = %d\n", (int)inputs.size());
            return nullptr;
        }
        const auto commaPos = left.find(",");
        if (commaPos == std::string::npos) {
            // Two inputs need two comma separated terms; without this check both terms
            // would silently become the whole left-hand side.
            MNN_ERROR("Can't convert Einsum for no support Equation:%s\n", equation.c_str());
            return nullptr;
        }
        const std::string input0 = left.substr(0, commaPos);
        const std::string input1 = left.substr(commaPos + 1);
        auto var0 = inputs[0];
        auto var1 = inputs[1];

        // dim = 4
        if (right.size() == 4 && input0.size() == 4 && input1.size() == 4) {
            // batch align:
            // bhwc,bhkc -> bhwk  batch = `bh`, reduce_dim = `c`
            // The reduce dim is the first label of input0 that is missing from the output.
            int reduceDimPos = -1;
            for (int i = 0; i < (int)input0.size(); ++i) {
                if (right.find(input0[i]) == std::string::npos) {
                    reduceDimPos = i;
                    break;
                }
            }
            // Without a contracted label this is not a matrix product: fall through to
            // the general handling below instead of emitting a MatMul.
            if (reduceDimPos >= 0) {
                const char reduceDim  = input0[reduceDimPos];
                const bool transposeA = reduceDimPos == (int)input0.size() - 2;
                const bool transposeB = input1.find(reduceDim) == (input1.size() - 1);
                // matmul: matmul auto broadcast such: `bhwc @ hkc` -> `bhwc @ bhkc`
                return finish(_MatMul(var0, var1, transposeA, transposeB));
            }
        }

        // dim = 3 with a shared leading batch label, e.g. bid, bjd -> bij
        if (right.size() == 3 && input0.size() == 3 && input1.size() == 3 &&
            input0[0] == input1[0] && input0[0] == right[0]) {
            if (input0[2] == input1[2]) { // bid, bjd
                return finish(_MatMul(var0, var1, false, true));
            }
            if (input0[2] == input1[1]) { // bid, bdj
                return finish(_MatMul(var0, var1, false, false));
            }
            if (input0[1] == input1[1]) { // bdi, bdj
                return finish(_MatMul(var0, var1, true, false));
            }
            if (input0[1] == input1[2]) { // bdi, bjd
                return finish(_MatMul(var0, var1, true, true));
            }
        }

        auto aShape = _Shape(var0, NCHW);
        auto bShape = _Shape(var1, NCHW);
        VARP prefixshape;
        // Expand the collapsed leading '.' dim of the result back to the saved prefix shape.
        auto preFixPostTreat = [&](VARP output) {
            if (right[0] != '.') {
                return output;
            }
            auto oShape = _Shape(output, NCHW);
            auto oRemainShape = _Slice(oShape, _Unsqueeze(_Scalar<int>(1), {0}), _Unsqueeze(_Rank(output) - _Scalar<int>(1), {0}));
            auto oPostShape = _Concat({prefixshape, oRemainShape}, 0);
            return _ReshapeF(output, oPostShape, MNN::MNN_DATA_FORMAT_NCHW);
        };
        if (hasPrefix) {
            // Separate prefix shape: all dims covered by a leading '.' are folded into a
            // single dim, so each remaining label maps to exactly one axis.
            if (input0[0] == '.') {
                auto remainA      = _Scalar<int>((int)input0.size() - 1);
                auto prefixRankA  = _Rank(var0) - remainA;
                auto aShapeRemain = _Slice(aShape, _Unsqueeze(prefixRankA, {0}), _Unsqueeze(remainA, {0}));
                prefixshape       = _Slice(aShape, _Unsqueeze(_Scalar<int>(0), {0}), _Unsqueeze(prefixRankA, {0}));
                auto newAShape    = _Concat({_Unsqueeze(_Scalar<int>(-1), {0}), aShapeRemain}, 0);
                var0   = _ReshapeF(var0, newAShape, MNN::MNN_DATA_FORMAT_NCHW);
                aShape = _Shape(var0, NCHW);
            }
            if (input1[0] == '.') {
                auto remainB      = _Scalar<int>((int)input1.size() - 1);
                // Derived from var1 itself: the first input may have no ellipsis at all.
                auto prefixRankB  = _Rank(var1) - remainB;
                auto bShapeRemain = _Slice(bShape, _Unsqueeze(prefixRankB, {0}), _Unsqueeze(remainB, {0}));
                if (nullptr == prefixshape) {
                    prefixshape = _Slice(bShape, _Unsqueeze(_Scalar<int>(0), {0}), _Unsqueeze(prefixRankB, {0}));
                }
                auto newBShape = _Concat({_Unsqueeze(_Scalar<int>(-1), {0}), bShapeRemain}, 0);
                var1   = _ReshapeF(var1, newBShape, MNN::MNN_DATA_FORMAT_NCHW);
                bShape = _Shape(var1, NCHW);
            }
        }

        // label -> axis for both operands (first occurrence wins).
        std::map<char, int> input0Pos;
        for (int i = 0; i < (int)input0.size(); ++i) {
            input0Pos.insert(std::make_pair(input0[i], i));
        }
        std::map<char, int> input1Pos;
        for (int i = 0; i < (int)input1.size(); ++i) {
            input1Pos.insert(std::make_pair(input1[i], i));
        }
        // Classify every output label: present in both operands, only in A or only in B.
        std::map<char, int> outputPos;
        std::vector<char> bothPos;
        std::vector<char> aPos;
        std::vector<char> bPos;
        for (int i = 0; i < (int)right.size(); ++i) {
            const char c = right[i];
            outputPos.insert(std::make_pair(c, i));
            const bool i0Find = input0Pos.find(c) != input0Pos.end();
            const bool i1Find = input1Pos.find(c) != input1Pos.end();
            if (i0Find && i1Find) {
                bothPos.emplace_back(c);
            } else if (i1Find) {
                bPos.emplace_back(c);
            } else if (i0Find) {
                aPos.emplace_back(c);
            } else {
                MNN_ASSERT(false);
            }
        }
        // Labels of the first operand that do not reach the output are summed.
        std::vector<char> sumPos;
        for (int i = 0; i < (int)input0.size(); ++i) {
            if (outputPos.find(input0[i]) == outputPos.end()) {
                sumPos.emplace_back(input0[i]);
            }
        }

        // dim < 4
        if (sumPos.empty()) {
            // Broadcast Mul
            bool aligned = true;
            // Reshape + Transpose: append a size-1 dim for every output label the operand
            // lacks (0 in the reshape keeps the existing dim), then reorder to output order.
            auto alignToOutput = [&](VARP var, const std::map<char, int>& labelPos) -> VARP {
                std::vector<int> reshapeDims(right.size(), 0);
                size_t insertPos = labelPos.size();
                std::vector<int> transpose;
                for (size_t i = 0; i < right.size(); ++i) {
                    auto iter = labelPos.find(right[i]);
                    if (iter != labelPos.end()) {
                        transpose.emplace_back(iter->second);
                        continue;
                    }
                    if (insertPos >= reshapeDims.size()) {
                        // The operand carries labels that are not part of the output;
                        // writing here would run past the end of reshapeDims.
                        aligned = false;
                        return VARP();
                    }
                    reshapeDims[insertPos] = 1;
                    transpose.emplace_back((int)insertPos);
                    ++insertPos;
                }
                auto shape = _Const(reshapeDims.data(), {static_cast<int>(right.size())}, NHWC, halide_type_of<int>());
                return _Permute(_ReshapeF(var, shape, MNN::MNN_DATA_FORMAT_NCHW), transpose);
            };
            var0 = alignToOutput(var0, input0Pos);
            if (aligned) {
                var1 = alignToOutput(var1, input1Pos);
            }
            if (!aligned) {
                MNN_ERROR("Can't convert Einsum for no support Equation:%s\n", equation.c_str());
                return nullptr;
            }
            auto output = var0 * var1;
            if (hasPrefix) {
                output = preFixPostTreat(output);
            }
            return finish(output);
        }

        // MatMul
        // Every summed label has to exist in the second operand as well; looking it up
        // with operator[] would otherwise silently invent axis 0 for it.
        for (size_t i = 0; i < sumPos.size(); ++i) {
            if (input1Pos.find(sumPos[i]) == input1Pos.end()) {
                MNN_ERROR("Can't convert Einsum for no support Equation:%s\n", equation.c_str());
                return nullptr;
            }
        }
        // aPos / bPos only hold output labels while sumPos only holds labels missing from
        // the output, so the two groups never overlap and need no filtering.
        auto one = _Unsqueeze(_Scalar<int>(1), {0});
        // 1-element tensor holding shape[axis].
        auto dimOf = [&](VARP shape, int axis) {
            return _Slice(shape, _Unsqueeze(_Scalar<int>(axis), {0}), one);
        };
        // outside and sum is common for A and B (both lengths are read from A's shape).
        VARP outsideLength = _Unsqueeze(_Scalar<int>(1), {0});
        for (size_t i = 0; i < bothPos.size(); ++i) {
            outsideLength = outsideLength * dimOf(aShape, input0Pos[bothPos[i]]);
        }
        VARP sumLength = _Unsqueeze(_Scalar<int>(1), {0});
        for (size_t i = 0; i < sumPos.size(); ++i) {
            sumLength = sumLength * dimOf(aShape, input0Pos[sumPos[i]]);
        }
        // Transpose to [shared..., own..., sum...] and reshape as 3 dimension
        // [outside, -1, sum].
        auto toBatchedMatrix = [&](VARP var, std::map<char, int>& labelPos, const std::vector<char>& ownLabels) {
            std::vector<int> transpose;
            for (size_t i = 0; i < bothPos.size(); ++i) {
                transpose.emplace_back(labelPos[bothPos[i]]);
            }
            for (size_t i = 0; i < ownLabels.size(); ++i) {
                transpose.emplace_back(labelPos[ownLabels[i]]);
            }
            for (size_t i = 0; i < sumPos.size(); ++i) {
                transpose.emplace_back(labelPos[sumPos[i]]);
            }
            auto packedShape = _Concat({outsideLength, _Unsqueeze(_Scalar<int>(-1), {0}), sumLength}, 0);
            return _ReshapeF(_Permute(var, transpose), packedShape, MNN::MNN_DATA_FORMAT_NCHW);
        };
        var0 = toBatchedMatrix(var0, input0Pos, aPos); // AB -> A -> sum
        var1 = toBatchedMatrix(var1, input1Pos, bPos); // AB -> B -> sum
        auto output = _MatMul(var0, var1, false, true);

        // Permute output if needed, origin dimension pos is AB - A - B
        std::vector<VARP> cShapeGroup;
        std::map<char, int> originOutputPos;
        for (size_t i = 0; i < bothPos.size(); ++i) {
            originOutputPos.insert(std::make_pair(bothPos[i], (int)cShapeGroup.size()));
            cShapeGroup.emplace_back(dimOf(aShape, input0Pos[bothPos[i]]));
        }
        for (size_t i = 0; i < aPos.size(); ++i) {
            originOutputPos.insert(std::make_pair(aPos[i], (int)cShapeGroup.size()));
            cShapeGroup.emplace_back(dimOf(aShape, input0Pos[aPos[i]]));
        }
        for (size_t i = 0; i < bPos.size(); ++i) {
            originOutputPos.insert(std::make_pair(bPos[i], (int)cShapeGroup.size()));
            cShapeGroup.emplace_back(dimOf(bShape, input1Pos[bPos[i]]));
        }
        auto cShape = _Concat(cShapeGroup, 0);
        output = _ReshapeF(output, cShape, MNN::MNN_DATA_FORMAT_NCHW);
        bool needPermute = false;
        std::vector<int> transpose(right.size());
        for (int i = 0; i < (int)right.size(); ++i) {
            transpose[i] = originOutputPos[right[i]];
            if (transpose[i] != i) {
                needPermute = true;
            }
        }
        if (needPermute) {
            output = _Permute(output, transpose);
        }
        if (hasPrefix) {
            output = preFixPostTreat(output);
        }
        return finish(output);
    }