void _ComputeLSTMOnnx()

in source/geometry/GeometryLSTM.cpp [62:476]


    // Lowers an ONNX-style recurrent op (LSTM, or RNN when `type == OpType_RNN`) into a
    // sequence of MNN Loop (OpType_While) commands appended to `res`.
    // Strategy, per direction:
    //   1. One batched MatMul computes Gate = X * W^T + B for ALL timesteps at once.
    //   2. The first timestep is encoded as its own loop (loopNumber == 1), either the
    //      no-initial-state fast path or the initial_h/initial_c path.
    //   3. The remaining seqLength-1 timesteps run in a serial loop (parallel == false)
    //      that feeds each step's hidden/cell state into the next.
    // For RNN, N == 1 (a single tanh gate); for LSTM, N == 4 gates laid out i,o,f,c.
    void _ComputeLSTMOnnx(const std::vector<Tensor*>& inputs, const std::vector<Tensor*>& outputs, Context& context,
                          CommandBuffer& res, const LSTM* lstm, OpType type) const {
        /* inputs:
        X: T The input sequences packed (and potentially padded) into one 3-D tensor with the shape of [seq_length,
        batch_size, input_size].

        W: T
        The weight tensor for the gates. Concatenation of W[iofc] and WB[iofc] (if bidirectional) along dimension 0. The
        tensor has shape [num_directions, 4*hidden_size, input_size].

        R: T
        The recurrence weight tensor. Concatenation of R[iofc] and RB[iofc] (if bidirectional) along dimension 0. This
        tensor has shape [num_directions, 4*hidden_size, hidden_size].

        B: T (optional)
        The bias tensor for input gate. [Wb[iofc] + Rb[iofc]], and [WBb[iofc] + RBb[iofc]] (if bidirectional) along
        dimension 0. This tensor has shape [num_directions, 4*hidden_size]. Optional: If not specified - assumed to be
        0.
         */
        MNN_ASSERT(inputs.size() >= 4);
        auto X_Input      = inputs[0];
        auto W            = inputs[1];
        auto R            = inputs[2];
        auto B            = inputs[3];
        // Optional initial hidden (initial_h) and cell (initial_c) states.
        Tensor* O_Init    = nullptr;
        Tensor* Cell_Init = nullptr;
        if (inputs.size() >= 5) {
            O_Init = inputs[4];
        }
        if (inputs.size() >= 6) {
            Cell_Init = inputs[5];
        }

        /** Outputs:
         Y: T (optional)
         A tensor that concats all the intermediate output values of the hidden. It has shape [seq_length,
         num_directions, batch_size, hidden_size].

         Y_h: T (optional)
         The last output value of the hidden. It has shape [num_directions, batch_size, hidden_size].

         Y_c: T (optional)
         The last output value of the cell. It has shape [num_directions, batch_size, hidden_size].
         */
        auto Y = outputs[0];
        // Y_h / Y_c are emitted as virtual views into Y / Cell (regions filled in per
        // direction at the end of `encode`), so no real memory is allocated for them.
        if (outputs.size() >= 2) {
            TensorUtils::getDescribe(outputs[1])->regions.clear();
            TensorUtils::getDescribe(outputs[1])->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        }
        if (outputs.size() >= 3) {
            TensorUtils::getDescribe(outputs[2])->regions.clear();
            TensorUtils::getDescribe(outputs[2])->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
        }

        auto seqLength     = X_Input->length(0);
        auto inputSize     = X_Input->length(2);
        auto batchSize     = X_Input->length(1);
        // hiddenSize / numDirections are read from Y's [seq, dir, batch, hidden] shape.
        auto hiddenSize    = Y->length(3);
        auto numDirections = Y->length(1);
        // Encode the whole computation for one direction.
        // X: flattened [seqLength*batchSize, inputSize] input (already reversed for direction == 1).
        // direction: 0 = forward, 1 = backward; selects the W/R/B slice and mirrors time offsets.
        auto encode = [&](Tensor* X, int direction) {
            const int N = (type == OpType_RNN ? 1 : 4);
            // FirstPart: Gate = MatMul(X, W, B) :  N * hiddenSize, seqLength * batchSize
            std::shared_ptr<Tensor> Gate(Tensor::createDevice<float>({seqLength * batchSize, N * hiddenSize}, Tensor::CAFFE));
            res.extras.emplace_back(Gate);
            {
                // Single MatMul over all timesteps: e = rows of X, l = reduce dim, h = gate columns.
                auto h = N * hiddenSize;
                auto e = seqLength * batchSize;
                auto l = inputSize;
                std::unique_ptr<OpT> newop(new OpT);
                newop->type = OpType_While;
                newop->main.value = new LoopParamT;
                newop->main.type = OpParameter_LoopParam;
                auto loop = newop->main.AsLoopParam();
                loop->tensorNumber = 4;
                loop->inputIndexes = {0, 1, 2};
                loop->outputIndexes = {3};
                loop->loopNumber = 1;
                std::unique_ptr<RegionCommandT> rcmd(new RegionCommandT);
                rcmd->size = {e, l, h};
                rcmd->view.resize(4);
                // A = X
                rcmd->view[1].reset(new ViewT);
                rcmd->view[1]->offset = 0;
                rcmd->view[1]->stride = {l, 1, 0};
                // W
                // Offset selects this direction's [N*hiddenSize, inputSize] weight slice.
                rcmd->view[2].reset(new ViewT);
                rcmd->view[2]->offset = direction * N * hiddenSize * inputSize;
                rcmd->view[2]->stride = {0, 1, l};
                // Bias
                rcmd->view[3].reset(new ViewT);
                rcmd->view[3]->offset = direction * N * hiddenSize;
                rcmd->view[3]->stride = {0, 0, 1};

                // C
                rcmd->view[0].reset(new ViewT);
                rcmd->view[0]->offset = 0;
                rcmd->view[0]->stride = {h, 0, 1};

                rcmd->indexes = {3, 0, 1, 2};// C, A, B, Bias
                rcmd->steps = {0, 0, 0, 0};
                rcmd->iterIndexes = {-1, -1, -1, -1};
                rcmd->op.reset(new OpT);
                rcmd->op->type = OpType_MatMul;
                rcmd->op->main.type = OpParameter_MatMul;
                rcmd->op->main.value = new MatMulT;
                // W is stored [N*hiddenSize, inputSize], so B-side is transposed.
                rcmd->op->main.AsMatMul()->transposeB = true;
                rcmd->op->main.AsMatMul()->transposeA = false;
                loop->commands.emplace_back(std::move(rcmd));
                flatbuffers::FlatBufferBuilder builder;
                builder.Finish(Op::Pack(builder, newop.get()));
                auto cmd = GeometryComputerUtils::makeCommand(builder, {X, W, B}, {Gate.get()});
                res.command.emplace_back(std::move(cmd));
            }

            // SecondPart: Compute outputs
            // Initial
            // Per-step scratch tensors, each [batchSize, hiddenSize]:
            //   I/C/F hold gate activations, O the hidden state, Cell the cell state.
            std::shared_ptr<Tensor> I(Tensor::createDevice<float>({batchSize, hiddenSize}, Tensor::CAFFE));
            std::shared_ptr<Tensor> C(Tensor::createDevice<float>({batchSize, hiddenSize}, Tensor::CAFFE));
            std::shared_ptr<Tensor> F(Tensor::createDevice<float>({batchSize, hiddenSize}, Tensor::CAFFE));
            std::shared_ptr<Tensor> O(Tensor::createDevice<float>({batchSize, hiddenSize}, Tensor::CAFFE));
            std::shared_ptr<Tensor> Cell(Tensor::createDevice<float>({batchSize, hiddenSize}, Tensor::CAFFE));
            res.extras.insert(res.extras.end(), {I, C, F, O, Cell});
            // First Output
            // Indices into the loop's virtual tensor table. They must stay consistent with
            // the inputIndexes/outputIndexes and makeCommand tensor lists built below
            // (index 2 = O and indexes 10/11 = O_Init/Cell_Init are used there as well).
            const int I_Y = 0;
            const int I_Cell = 1;
            const int I_Gate = 3;
            const int I_I = 4;
            const int I_C = 5;
            const int I_F = 6;
            const int I_R = 7;
            const int I_HR = 8;
            const int I_Temp = 9;
            // Encode dst = unOp(Gate[offsetGate..] biOp HRTotal[offsetHR..]) for one gate:
            // a Binary command into the Temp tensor, then a Unary command into dstIndex.
            auto subEncoder = [&](int dstIndex, UnaryOpOperation unOp, int biOp, int offsetGate, int offsetHR, LoopParamT* loop) {
                // Binary
                {
                    std::unique_ptr<RegionCommandT> rcmd(new RegionCommandT);
                    rcmd->size = {1, batchSize, hiddenSize};
                    rcmd->indexes = {I_Temp, I_Gate, I_HR};
                    rcmd->iterIndexes = {-1, -1, -1};
                    // Gate advances one timestep (N*hiddenSize per batch row) per loop iteration.
                    rcmd->steps = {0, batchSize * hiddenSize * N, 0};
                    rcmd->view.resize(3);
                    rcmd->view[0].reset(new ViewT);
                    rcmd->view[0]->offset = 0;
                    rcmd->view[0]->stride = {hiddenSize * batchSize, hiddenSize, 1};
                    rcmd->view[1].reset(new ViewT);
                    rcmd->view[1]->offset = offsetGate;
                    rcmd->view[1]->stride = {N * hiddenSize * seqLength * batchSize, N * hiddenSize, 1};
                    rcmd->view[2].reset(new ViewT);
                    rcmd->view[2]->offset = offsetHR;
                    rcmd->view[2]->stride = {N * hiddenSize * batchSize, N * hiddenSize, 1};
                    rcmd->op.reset(new OpT);
                    rcmd->op->type = OpType_BinaryOp;
                    rcmd->op->main.type = OpParameter_BinaryOp;
                    rcmd->op->main.value = new BinaryOpT;
                    rcmd->op->main.AsBinaryOp()->opType = biOp;
                    loop->commands.emplace_back(std::move(rcmd));
                }
                // Unary
                {
                    std::unique_ptr<RegionCommandT> rcmd(new RegionCommandT);
                    rcmd->size = {1, 1, hiddenSize * batchSize};
                    rcmd->indexes = {dstIndex, I_Temp};
                    rcmd->iterIndexes = {-1, -1};
                    rcmd->steps = {0, 0};
                    rcmd->view.resize(2);
                    rcmd->view[1].reset(new ViewT);
                    rcmd->view[1]->offset = 0;
                    rcmd->view[1]->stride = {0, 0, 1};
                    rcmd->view[0].reset(new ViewT);
                    rcmd->view[0]->offset = 0;
                    rcmd->view[0]->stride = {0, 0, 1};
                    rcmd->op.reset(new OpT);
                    rcmd->op->type = OpType_UnaryOp;
                    rcmd->op->main.type = OpParameter_UnaryOp;
                    rcmd->op->main.value = new UnaryOpT;
                    rcmd->op->main.AsUnaryOp()->opType = unOp;
                    loop->commands.emplace_back(std::move(rcmd));
                }
            };
            // HRTotal holds H_{t-1} * R^T for all N gates; Temp is the binary-op scratch.
            std::shared_ptr<Tensor> HRTotal(Tensor::createDevice<float>({batchSize, N * hiddenSize}, Tensor::CAFFE));
            res.extras.emplace_back(HRTotal);
            std::shared_ptr<Tensor> Temp(Tensor::createDevice<float>({batchSize, hiddenSize}, Tensor::CAFFE));
            res.extras.emplace_back(Temp);

            // Encode one recurrent timestep into `loop`:
            //   HR = H_{t-1} * R^T, then the gate activations and state updates.
            // `start` is the logical first timestep of the loop; oInit/cellInit name the
            // loop-table tensors that carry the previous hidden/cell state (either the
            // scratch O/Cell, the user-supplied initial state, or Y itself for t >= 1).
            // For the backward direction, time offsets are mirrored and `step` negated.
            auto sequenceEncode = [&](int start, int oInit, int cellInit, LoopParamT* loop) {
                int pos = start;
                int step = hiddenSize * batchSize * numDirections;
                if (direction) {
                    pos = seqLength - 1 - start;
                    step = -step;
                }
                // Destination offset of this timestep's hidden state inside Y
                // ([seq, dir, batch, hidden] layout).
                int offset = hiddenSize * batchSize * pos * numDirections + direction * batchSize * hiddenSize;

                // Compute HR = MatMul(R, O)
                {
                    std::unique_ptr<RegionCommandT> rcmd(new RegionCommandT);
                    rcmd->size = {N * hiddenSize, hiddenSize, batchSize};
                    rcmd->indexes = {I_HR, I_R, oInit};
                    rcmd->iterIndexes = {-1, -1, -1};
                    // Previous-hidden view walks through Y one timestep per iteration.
                    rcmd->steps = {0, 0, step};
                    rcmd->op.reset(new OpT);
                    rcmd->op->type = OpType_MatMul;
                    rcmd->op->main.type = OpParameter_MatMul;
                    rcmd->op->main.value = new MatMulT;
                    rcmd->op->main.AsMatMul()->transposeB = true;
                    rcmd->op->main.AsMatMul()->transposeA = false;
                    rcmd->view.resize(3);
                    rcmd->view[0].reset(new ViewT);
                    rcmd->view[0]->offset = 0;
                    rcmd->view[0]->stride = {1, 0, N * hiddenSize};
                    rcmd->view[1].reset(new ViewT);
                    rcmd->view[1]->offset = direction * N * hiddenSize * hiddenSize;
                    rcmd->view[1]->stride = {batchSize, 1, 0};
                    rcmd->view[2].reset(new ViewT);
                    if (oInit != I_Y) {
                        // Previous hidden comes from a state tensor (initial_h slice per direction).
                        rcmd->view[2]->offset = O->elementSize() * direction;
                    } else {
                        // Previous hidden comes from Y at timestep start-1 (mirrored when backward).
                        int pre = start - 1;
                        if (direction) {
                            pre = seqLength - 1 - pre;
                        }
                        rcmd->view[2]->offset = hiddenSize * batchSize * pre * numDirections + direction * batchSize * hiddenSize;
                    }
                    rcmd->view[2]->stride = {0, batchSize, 1};
                    loop->commands.emplace_back(std::move(rcmd));
                }

                if (type == OpType_RNN) {
                    // RNN: H_t = tanh(Gate_t + HR), written straight into Y at `offset`.
                    subEncoder(I_Y, UnaryOpOperation_TANH, BinaryOpOperation_ADD, start * batchSize * hiddenSize, 0, loop);
                    loop->commands[loop->commands.size() - 1]->view[0]->offset = offset;
                    loop->commands[loop->commands.size() - 1]->steps[0] = step;
                    return;
                }
                // LSTM gate offsets follow the ONNX i,o,f,c concatenation:
                // 0 -> i, 1*hiddenSize -> o, 2*hiddenSize -> f, 3*hiddenSize -> c.
                // I = Sigmoid(WI * XI + BI + HRI)
                {
                    subEncoder(I_I, UnaryOpOperation_SIGMOID, BinaryOpOperation_ADD, start * batchSize * 4 * hiddenSize, 0, loop);
                }
                // C = tanh(WC * XC + BC + HRC)
                {
                    subEncoder(I_C, UnaryOpOperation_TANH, BinaryOpOperation_ADD, 3 * hiddenSize + start * batchSize * 4 * hiddenSize, 3 * hiddenSize, loop);
                }
                // F = Sigmoid(WF * XF + BF + HRF)
                {
                    subEncoder(I_F, UnaryOpOperation_SIGMOID, BinaryOpOperation_ADD, 2 * hiddenSize + start * batchSize * 4 * hiddenSize, 2 * hiddenSize, loop);
                }
                // Cell = I * C + F * Cell
                // NOTE(review): easyBinaryEncode/easyUnaryEncode are helpers defined elsewhere
                // in this file; the trailing arguments appear to be (dstOffset/srcOffset, step,
                // offset) view adjustments — confirm against their definitions.
                {
                    easyBinaryEncode(hiddenSize * batchSize, {I_Temp, I_I, I_C}, BinaryOpOperation_MUL, loop);
                    // First step may read the previous cell from Cell_Init (per-direction slice).
                    auto cellOffset = cellInit == I_Cell ? 0 : Cell->elementSize() * direction;
                    easyBinaryEncode(hiddenSize * batchSize, {I_I, I_F, cellInit}, BinaryOpOperation_MUL, loop, cellOffset);
                    easyBinaryEncode(hiddenSize * batchSize, {I_Cell, I_Temp, I_I}, BinaryOpOperation_ADD, loop);
                }
                // C = Sigmoid(WO * XO + BO + HRO)
                // (C is reused here to hold the output-gate activation.)
                {
                    subEncoder(I_C, UnaryOpOperation_SIGMOID, BinaryOpOperation_ADD, 1 * hiddenSize + start * batchSize * 4 * hiddenSize, 1 * hiddenSize, loop);
                }
                // I = tanh(Cell), O = I * C
                // Hidden state H_t = o * tanh(Cell), written into Y at `offset`, stepping by `step`.
                {
                    easyUnaryEncode({I_I, I_Cell}, UnaryOpOperation_TANH, loop, hiddenSize * batchSize);
                    easyBinaryEncode(hiddenSize * batchSize, {I_Y, I_I, I_C}, BinaryOpOperation_MUL, loop, 0, step, offset);
                }
            };
            if (nullptr == O_Init && nullptr == Cell_Init) {
                // Fast path for timestep 0 with no initial state: HR term is zero, so the
                // gates reduce to activations of Gate directly (no recurrent MatMul needed).
                std::unique_ptr<OpT> newop(new OpT);
                newop->type = OpType_While;
                newop->main.value = new LoopParamT;
                newop->main.type = OpParameter_LoopParam;
                auto loop = newop->main.AsLoopParam();
                // Y, Cell, O, Gate, I, C, F
                loop->tensorNumber = 7;
                loop->inputIndexes = {3};
                loop->outputIndexes = {0, 1, 2, 4, 5, 6};
                loop->loopNumber = 1;
                // dst = unOp(Gate[gate `index`]) — single Unary command, no HR add.
                auto unaryGateEncode = [&](UnaryOpOperation unOp, int dstIndex, int index, LoopParamT* loop) {
                    std::unique_ptr<RegionCommandT> rcmd(new RegionCommandT);
                    rcmd->size = {1, batchSize, hiddenSize};
                    rcmd->indexes = {dstIndex, I_Gate};
                    rcmd->iterIndexes = {-1, -1};
                    rcmd->steps = {0, 0};
                    rcmd->view.resize(2);
                    rcmd->view[1].reset(new ViewT);
                    rcmd->view[1]->offset = index * hiddenSize;
                    rcmd->view[1]->stride = {N * hiddenSize * seqLength * batchSize, N * hiddenSize, 1};
                    rcmd->view[0].reset(new ViewT);
                    rcmd->view[0]->offset = 0;
                    rcmd->view[0]->stride = {hiddenSize * batchSize, hiddenSize, 1};
                    rcmd->op.reset(new OpT);
                    rcmd->op->type = OpType_UnaryOp;
                    rcmd->op->main.type = OpParameter_UnaryOp;
                    rcmd->op->main.value = new UnaryOpT;
                    rcmd->op->main.AsUnaryOp()->opType = unOp;
                    loop->commands.emplace_back(std::move(rcmd));
                };
                if (type == OpType_RNN) {
                    unaryGateEncode(UnaryOpOperation_TANH, I_Y, 0, loop);
                    // Redirect the write into Y at the first timestep of this direction
                    // (forward: t = 0; backward: t = seqLength - 1).
                    loop->commands[loop->commands.size() - 1]->view[0]->offset = direction * (batchSize * hiddenSize) * (1 + (seqLength - 1) * numDirections);
                } else {
                    // I = Sigmoid(WI * XI + BI)
                    unaryGateEncode(UnaryOpOperation_SIGMOID, I_I, 0, loop);

                    // C = tanh(WC * XC + BC)
                    unaryGateEncode(UnaryOpOperation_TANH, I_C, 3, loop);

                    // Cell = I * C
                    // (Forget-gate term vanishes: previous cell state is zero.)
                    easyBinaryEncode(hiddenSize * batchSize, {I_Cell, I_I, I_C}, BinaryOpOperation_MUL, loop);

                    // C = Sigmoid(WO * XO + BO)
                    unaryGateEncode(UnaryOpOperation_SIGMOID, I_C, 1, loop);

                    // I = tanh(Cell)
                    easyUnaryEncode({I_I, I_Cell}, UnaryOpOperation_TANH, loop, hiddenSize * batchSize);

                    // O = I * C
                    // Written into Y at this direction's first timestep.
                    easyBinaryEncode(hiddenSize * batchSize, {I_Y, I_I, I_C}, BinaryOpOperation_MUL, loop, 0, 0, direction * ((batchSize * hiddenSize) + (seqLength - 1) * numDirections * batchSize * hiddenSize));
                }
                flatbuffers::FlatBufferBuilder builder;
                builder.Finish(Op::Pack(builder, newop.get()));
                auto cmd = GeometryComputerUtils::makeCommand(builder, {Gate.get()}, {Y, Cell.get(), O.get(), I.get(), C.get(), F.get()});
                res.command.emplace_back(std::move(cmd));
            } else {
                // Has Init O and Cell
                // Timestep 0 with user-supplied initial_h (and initial_c for LSTM):
                // a full recurrent step reading the previous state from the init tensors.
                std::unique_ptr<OpT> newop(new OpT);
                newop->type = OpType_While;
                newop->main.value = new LoopParamT;
                newop->main.type = OpParameter_LoopParam;
                auto loop = newop->main.AsLoopParam();
                // Y, Cell, O, Gate, I, C, F, O_Init, Cell_Init
                const int I_OInit = 10;
                const int I_CellInit = 11;
                std::vector<Tensor*> inputs;
                if (type == OpType_RNN) { // only provide initial_h
                    loop->tensorNumber = 11;
                    loop->inputIndexes = {3, 7, 10};
                    inputs.assign({Gate.get(), R, O_Init});
                } else {
                    loop->tensorNumber = 12;
                    loop->inputIndexes = {3, 7, 10, 11};
                    inputs.assign({Gate.get(), R, O_Init, Cell_Init});
                }
                loop->outputIndexes = {0, 4, 5, 6, 8, 9, 2, 1};
                loop->loopNumber = 1;
                std::vector<Tensor*> suboutputs = {
                    Y, I.get(), C.get(), F.get(), HRTotal.get(), Temp.get(), O.get(), Cell.get()
                };
                sequenceEncode(0, I_OInit, I_CellInit, loop);
                flatbuffers::FlatBufferBuilder builder;
                builder.Finish(Op::Pack(builder, newop.get()));
                auto cmd = GeometryComputerUtils::makeCommand(builder, inputs, suboutputs);
                res.command.emplace_back(std::move(cmd));
            }
            // 1 - seqLength
            // Serial loop over the remaining timesteps: each iteration reads the previous
            // hidden state from Y (I_Y) and the cell state from Cell (I_Cell), so
            // parallel execution must be disabled.
            {
                std::unique_ptr<OpT> newop(new OpT);
                newop->type = OpType_While;
                newop->main.value = new LoopParamT;
                newop->main.type = OpParameter_LoopParam;
                auto loop = newop->main.AsLoopParam();
                loop->parallel = false;
                // Y, Cell, O, Gate, I, C, F, R, Temp
                loop->tensorNumber = 10;
                loop->inputIndexes = {3, 7, 2, 1};
                loop->outputIndexes = {0, 4, 5, 6, 8, 9};
                loop->loopNumber = seqLength - 1;
                std::vector<Tensor*> inputs = {
                    Gate.get(), R, O.get(), Cell.get()
                };
                std::vector<Tensor*> suboutputs = {
                    Y, I.get(), C.get(), F.get(), HRTotal.get(), Temp.get()
                };
                sequenceEncode(1, I_Y, I_Cell, loop);
                flatbuffers::FlatBufferBuilder builder;
                builder.Finish(Op::Pack(builder, newop.get()));
                auto cmd = GeometryComputerUtils::makeCommand(builder, inputs, suboutputs);
                res.command.emplace_back(std::move(cmd));
            }
            // Y_h: a view into Y at the last logical timestep of this direction
            // (forward: seqLength - 1; backward: 0).
            if (outputs.size() >= 2) {
                int pos = seqLength - 1;
                if (direction) {
                    pos = 0;
                }
                int offset = hiddenSize * batchSize * pos * numDirections + direction * batchSize * hiddenSize;

                TensorUtils::getDescribe(outputs[1])->regions.emplace_back(GeometryComputerUtils::makeRawAddressRef(Y, offset, O->elementSize(), O->elementSize() * direction));
            }
            // Y_c: a view onto the Cell scratch tensor, placed at this direction's slice.
            if (outputs.size() >= 3) {
                TensorUtils::getDescribe(outputs[2])->regions.emplace_back(GeometryComputerUtils::makeRawAddressRef(Cell.get(), 0, Cell->elementSize(), Cell->elementSize() * direction));
            }
        };
        // Flatten X to [seqLength*batchSize, inputSize] and encode the forward direction.
        std::shared_ptr<Tensor> XWrap(Tensor::createDevice<float>({seqLength * batchSize, inputSize}, Tensor::CAFFE));
        GeometryComputerUtils::makeRawAddressRef(XWrap.get(), X_Input, 0, seqLength * batchSize * inputSize);
        res.extras.emplace_back(XWrap);
        encode(XWrap.get(), 0);
        if (numDirections > 1) {
            // Create Reverse X
            // Virtual tensor that reads X's timesteps in reverse order (negative src
            // stride starting from the last timestep) for the backward direction.
            std::shared_ptr<Tensor> XReverse(Tensor::createDevice<float>({seqLength * batchSize, inputSize}, Tensor::CAFFE));
            res.extras.emplace_back(XReverse);
            auto des = TensorUtils::getDescribe(XReverse.get());
            des->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            des->regions.resize(1);
            auto& reg = des->regions[0];
            reg.size[0] = 1;
            reg.size[1] = seqLength;
            reg.size[2] = batchSize * inputSize;
            reg.src.offset = batchSize * inputSize * (seqLength-1);
            reg.src.stride[0] = 0;
            reg.src.stride[1] = -(batchSize * inputSize);
            reg.src.stride[2] = 1;
            reg.dst.offset = 0;
            reg.dst.stride[0] = 0;
            reg.dst.stride[1] = batchSize * inputSize;
            reg.dst.stride[2] = 1;
            reg.origin = X_Input;
            // Encode XReverse
            encode(XReverse.get(), 1);
        }
    }