virtual bool onCompute()

in source/geometry/GeometryBatchMatMul.cpp [138:469]


    /**
     * Lower a MatMul / BatchMatMul op into backend commands.
     *
     * - If either operand is empty, the output becomes a virtual tensor with no regions.
     * - 1-D operands are promoted to 2-D (A: [l] -> [1, l], B: [l] -> [l, 1]). The result is then
     *   computed into an unsqueezed temporary of shape [..., e, h], and the real output is
     *   redirected to that temporary through a full-slice region.
     * - A 2-D result is emitted as a single command that reuses `op` directly.
     * - A batched result is emitted as a While op with a LoopParam whose single RegionCommand runs
     *   MatMul over an [e, l, h] iteration space once per output batch element. Two cases apply:
     *   - Each operand is either a single batch or has one batch per output batch: per-iteration
     *     offsets are constant steps.
     *   - Otherwise: per-iteration batch offsets for both operands are precomputed into two int
     *     constant tensors.
     *
     * @return false only if the offset constant tensors cannot be allocated; true otherwise.
     */
    virtual bool onCompute(const Op* op, const std::vector<Tensor*>& inputs,
                                    const std::vector<Tensor*>& outputs, Context& context, CommandBuffer& res) const override {
        // Transpose flags live in different parameter tables depending on the op type.
        bool transposeA = false;
        bool transposeB = false;
        if (op->type() == OpType_BatchMatMul) {
            auto param = op->main_as_BatchMatMulParam();
            transposeA = param->adjX();
            transposeB = param->adjY();
        } else {
            auto param = op->main_as_MatMul();
            transposeA = param->transposeA();
            transposeB = param->transposeB();
        }

        auto input0          = inputs[0];
        auto input1          = inputs[1];
        Tensor* bias         = nullptr;
        auto output          = outputs[0];
        if (inputs.size() > 2) {
            bias = inputs[2];
        }
        auto outputDes = TensorUtils::getDescribe(output);
        // Empty operand: the output becomes a virtual tensor with no source regions.
        // NOTE(review): presumably the raster stage leaves such a tensor zero-filled -- confirm.
        if (input0->elementSize() == 0 || input1->elementSize() == 0) {
            outputDes->regions.clear();
            outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            return true;
        }
        // Promote 1-D operands to 2-D. Each promotion inserts one unit dimension that the real
        // output does not have; outputNeedSqueeze counts them.
        // NOTE(review): the promotions ignore transposeA / transposeB, so a 1-D operand combined
        // with a transpose flag would pick up the wrong e / h below -- confirm that shape
        // inference never produces that combination.
        int outputNeedSqueeze = 0;
        if (input0->dimensions() < 2) {
            // A: [l] -> [1, l], exposed as a virtual view of the original input.
            std::shared_ptr<Tensor> newTensor(new Tensor);
            TensorUtils::copyShape(input0, newTensor.get(), true);
            newTensor->buffer().type = input0->buffer().type;
            newTensor->buffer().dimensions = 2;
            newTensor->setLength(0, 1);
            newTensor->setLength(1, input0->length(0));
            TensorUtils::getDescribe(newTensor.get())->regions = {TensorUtils::makeFullSlice(input0)};
            TensorUtils::getDescribe(newTensor.get())->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            input0 = newTensor.get();
            res.extras.emplace_back(newTensor);
            outputNeedSqueeze++;
        }
        if (input1->dimensions() < 2) {
            // B: [l] -> [l, 1], exposed as a virtual view of the original input.
            std::shared_ptr<Tensor> newTensor(new Tensor);
            TensorUtils::copyShape(input1, newTensor.get(), true);
            newTensor->buffer().type = input1->buffer().type;
            newTensor->buffer().dimensions = 2;
            newTensor->setLength(0, input1->length(0));
            newTensor->setLength(1, 1);
            TensorUtils::getDescribe(newTensor.get())->regions = {TensorUtils::makeFullSlice(input1)};
            TensorUtils::getDescribe(newTensor.get())->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            input1 = newTensor.get();
            res.extras.emplace_back(newTensor);
            outputNeedSqueeze++;
        }
        // Matrix dims: A is [e, l] (or [l, e] when transposed), B is [l, h] (or [h, l]).
        int input0_end1 = input0->length(input0->dimensions()-2);
        int input0_end0 = input0->length(input0->dimensions()-1);
        int input1_end1 = input1->length(input1->dimensions()-2);
        int input1_end0 = input1->length(input1->dimensions()-1);
        int e = input0_end1;
        int l = input0_end0;
        int h = input1_end0;
        if (transposeA) {
            e = input0_end0;
            l = input0_end1;
        }
        if (transposeB) {
            h = input1_end1;
        }
        if (outputNeedSqueeze > 0) {
            // Compute into an unsqueezed temporary and redirect the real output to it through a
            // full-slice region.
            std::shared_ptr<Tensor> newTensor(new Tensor);
            TensorUtils::copyShape(output, newTensor.get(), true);
            newTensor->buffer().dimensions = output->dimensions() + outputNeedSqueeze;
            // A matmul result is laid out as [..., e, h]: rows come from A, columns from B.
            newTensor->setLength(newTensor->dimensions() - 2, e);
            newTensor->setLength(newTensor->dimensions() - 1, h);
            newTensor->buffer().type = output->buffer().type;
            outputDes->regions = {TensorUtils::makeFullSlice(newTensor.get())};
            outputDes->memoryType = Tensor::InsideDescribe::MEMORY_VIRTUAL;
            res.extras.emplace_back(newTensor);

            output = newTensor.get();
            outputDes = TensorUtils::getDescribe(output);
        }

        if (output->dimensions() == 2) {
            // Plain (non-batched) MatMul: forward the original op unchanged.
            std::shared_ptr<Command> cmd(new Command);
            cmd->op      = op;
            if (bias == nullptr) {
                cmd->inputs  = {input0, input1};
            } else {
                cmd->inputs  = {input0, input1, bias};
            }
            cmd->outputs = {output};
            res.command.emplace_back(cmd);
            return true;
        }
        // Batched / broadcast MatMul (bias supported): split into one MatMul per output batch.
        outputDes->memoryType = Tensor::InsideDescribe::MEMORY_BACKEND;
        // Compute broadcast strides over the leading batch dims (all but the last two).
        const int maxDimensions = output->dimensions() - 2;
        int outputStrides[MNN_MAX_TENSOR_DIM];
        int input0Strides[MNN_MAX_TENSOR_DIM];
        int input1Strides[MNN_MAX_TENSOR_DIM];
        auto i0Offset = output->dimensions() - input0->dimensions();
        auto i1Offset = output->dimensions() - input1->dimensions();
        int totalSize = 1; // number of output batches
        int i0Size = 1;    // number of distinct input0 batches
        int i1Size = 1;    // number of distinct input1 batches
        for (int i = maxDimensions - 1; i >=0 ; --i) {
            outputStrides[i] = totalSize;
            // A stride of 0 means this operand is broadcast (or absent) along batch dim i.
            input0Strides[i] = 0;
            input1Strides[i] = 0;
            totalSize *= output->length(i);
            if (i >= i0Offset && input0->length(i - i0Offset) > 1) {
                input0Strides[i] = i0Size;
                i0Size *= input0->length(i - i0Offset);
            }
            if (i >= i1Offset && input1->length(i - i1Offset) > 1) {
                input1Strides[i] = i1Size;
                i1Size *= input1->length(i - i1Offset);
            }
        }
        flatbuffers::FlatBufferBuilder builder;

        // Region views over the [e, l, h] iteration space, in this order:
        //   view 0: output, view 1: input0, view 2: input1, view 3 (optional): bias.
        std::vector<flatbuffers::Offset<View>> allViews(3);
        int size[] = {e, l, h};
        // Per-batch element steps, same order as the views. Bias is shared by all batches.
        int steps[] = {e*h, e*l, l*h, 0};
        auto sizeOffset = builder.CreateVector(size, 3);
        {
            // Output [e, h]: independent of l.
            int stride[] = {h, 0, 1};
            auto strideOffset = builder.CreateVector(stride, 3);
            ViewBuilder viewB(builder);
            viewB.add_offset(0);
            viewB.add_stride(strideOffset);
            allViews[0] = viewB.Finish();
        }
        {
            // input0 [e, l] or, transposed, [l, e]: independent of h.
            int stride[3];
            stride[2] = 0;
            if (transposeA) {
                stride[0] = 1;
                stride[1] = e;
            } else {
                stride[1] = 1;
                stride[0] = l;
            }
            auto strideOffset = builder.CreateVector(stride, 3);
            ViewBuilder viewB(builder);
            viewB.add_offset(0);
            viewB.add_stride(strideOffset);
            allViews[1] = viewB.Finish();
        }
        {
            // input1 [l, h] or, transposed, [h, l]: independent of e.
            int stride[3];
            stride[0] = 0;
            if (transposeB) {
                stride[1] = 1;
                stride[2] = l;
            } else {
                stride[1] = h;
                stride[2] = 1;
            }
            auto strideOffset = builder.CreateVector(stride, 3);
            ViewBuilder viewB(builder);
            viewB.add_offset(0);
            viewB.add_stride(strideOffset);
            allViews[2] = viewB.Finish();
        }
        if (bias != nullptr) {
            // Bias [h]: broadcast over e and l.
            int stride[3] = {0, 0, 1};
            auto strideOffset = builder.CreateVector(stride, 3);
            ViewBuilder viewB(builder);
            viewB.add_offset(0);
            viewB.add_stride(strideOffset);
            allViews.emplace_back(viewB.Finish());
        }
        flatbuffers::Offset<flatbuffers::String> nameOffset;
        if (nullptr != op->name()) {
            nameOffset = builder.CreateString(op->name()->c_str());
        }
        MatMulBuilder matMulParam(builder);
        matMulParam.add_transposeA(transposeA);
        matMulParam.add_transposeB(transposeB);
        auto matMulParamOffset = matMulParam.Finish();
        OpBuilder matMulOp(builder);
        matMulOp.add_type(OpType_MatMul);
        matMulOp.add_main(matMulParamOffset.Union());
        matMulOp.add_main_type(OpParameter_MatMul);
        auto opOffset = matMulOp.Finish();
        // Constant steps are valid only if each operand is either a single batch (reused by every
        // iteration) or has exactly one batch per output batch. Equal but partial batch counts,
        // e.g. [2, 1, e, l] x [1, 2, l, h], broadcast along different axes and need explicit
        // per-iteration offsets.
        const bool input0FixedStep = (i0Size == 1) || (i0Size == totalSize);
        const bool input1FixedStep = (i1Size == 1) || (i1Size == totalSize);
        bool fastway = input0FixedStep && input1FixedStep;
        if (fastway) {
            // Loop tensor slots: 0 = input0, 1 = input1, 2 = output, 3 = bias (optional).
            int inputNumber = 2;
            if (bias != nullptr) {
                inputNumber = 3;
            }
            // A single-batch operand is reused by every iteration.
            if (1 == i0Size) {
                steps[1] = 0;
            }
            if (1 == i1Size) {
                steps[2] = 0;
            }
            int number = inputNumber + 1;
            auto viewOffset = builder.CreateVector<flatbuffers::Offset<View>>(allViews);
            int indexes[] = {2, 0, 1, 3};
            int iterIndexes[] = {-1, -1, -1, -1};
            auto indexOffset = builder.CreateVector(indexes, number);
            auto iterIndexesOffset = builder.CreateVector(iterIndexes, number);
            auto stepOffset = builder.CreateVector(steps, number);
            RegionCommandBuilder rgCmdBuilder(builder);
            rgCmdBuilder.add_op(opOffset);
            rgCmdBuilder.add_size(sizeOffset);
            rgCmdBuilder.add_view(viewOffset);
            rgCmdBuilder.add_iterIndexes(iterIndexesOffset);
            rgCmdBuilder.add_indexes(indexOffset);
            rgCmdBuilder.add_steps(stepOffset);
            auto regionCommandOffset = rgCmdBuilder.Finish();

            int inputIndexes[] = {0, 1, 3};
            auto inputIndexesOffset = builder.CreateVector(inputIndexes, inputNumber);
            int outputIndexes[] = {2};
            auto outputIndexOffset = builder.CreateVector(outputIndexes, 1);

            auto cmdOffset = builder.CreateVector(&regionCommandOffset, 1);
            LoopParamBuilder lpBuilder(builder);
            lpBuilder.add_commands(cmdOffset);
            lpBuilder.add_parallel(true);
            lpBuilder.add_inputIndexes(inputIndexesOffset);
            lpBuilder.add_outputIndexes(outputIndexOffset);
            lpBuilder.add_loopNumber(totalSize);
            lpBuilder.add_tensorNumber(number);
            auto lpOffset = lpBuilder.Finish();

            OpBuilder opBuilder(builder);
            opBuilder.add_main(lpOffset.Union());
            opBuilder.add_main_type(OpParameter_LoopParam);
            opBuilder.add_type(OpType_While);
            if (nullptr != op->name()) {
                opBuilder.add_name(nameOffset);
            }
            builder.Finish(opBuilder.Finish());
            std::vector<Tensor*> inputLoops{input0, input1};
            if (nullptr != bias) {
                inputLoops.emplace_back(bias);
            }
            auto cmd = GeometryComputerUtils::makeCommand(builder, inputLoops, {output});
            res.command.emplace_back(std::move(cmd));
            return true;
        }
        // General broadcast: precompute, for every output batch, which batch of each operand it reads.
        auto i0OffsetTensor = context.allocConst(op, {totalSize}, halide_type_of<int>());
        auto i1OffsetTensor = context.allocConst(op, {totalSize}, halide_type_of<int>());
        if (nullptr == i0OffsetTensor || nullptr == i1OffsetTensor) {
            return false;
        }
        auto i0OffsetTensorPtr = i0OffsetTensor->host<int>();
        auto i1OffsetTensorPtr = i1OffsetTensor->host<int>();
        for (int index = 0; index < totalSize; ++index) {
            // Unroll the flat batch index into per-dimension coordinates (outermost first) and map
            // each coordinate through the broadcast-aware strides of both operands.
            int remain = index;
            int batch0 = 0;
            int batch1 = 0;
            for (int i = 0; i < maxDimensions; ++i) {
                const int cord = remain / outputStrides[i];
                batch0 += input0Strides[i] * cord;
                batch1 += input1Strides[i] * cord;
                remain %= outputStrides[i];
            }
            i0OffsetTensorPtr[index] = batch0;
            i1OffsetTensorPtr[index] = batch1;
        }
        // Loop tensor slots: 0 = input0, 1 = input1, 2 = input0 batch offsets,
        // 3 = input1 batch offsets, 4 = output, 5 = bias (optional).
        int inputNumber = 4;
        if (bias != nullptr) {
            inputNumber = 5;
        }
        int number = inputNumber + 1;
        // One region view per operand plus the output; the two offset tensors have no view.
        int rgNumber = number - 2;
        auto viewOffset = builder.CreateVector<flatbuffers::Offset<View>>(allViews);
        int indexes[] = {4, 0, 1, 5};
        // input0 / input1 read their per-iteration batch index from slots 2 / 3.
        int iterIndexes[] = {-1, 2, 3, -1};
        auto indexOffset = builder.CreateVector(indexes, rgNumber);
        auto iterIndexesOffset = builder.CreateVector(iterIndexes, rgNumber);
        auto stepOffset = builder.CreateVector(steps, rgNumber);
        RegionCommandBuilder rgCmdBuilder(builder);
        rgCmdBuilder.add_op(opOffset);
        rgCmdBuilder.add_size(sizeOffset);
        rgCmdBuilder.add_view(viewOffset);
        rgCmdBuilder.add_iterIndexes(iterIndexesOffset);
        rgCmdBuilder.add_indexes(indexOffset);
        rgCmdBuilder.add_steps(stepOffset);
        auto regionCommandOffset = rgCmdBuilder.Finish();

        int inputIndexes[] = {0, 1, 2, 3, 5};
        auto inputIndexesOffset = builder.CreateVector(inputIndexes, inputNumber);
        int outputIndexes[] = {4};
        auto outputIndexOffset = builder.CreateVector(outputIndexes, 1);

        auto cmdOffset = builder.CreateVector(&regionCommandOffset, 1);
        LoopParamBuilder lpBuilder(builder);
        lpBuilder.add_commands(cmdOffset);
        lpBuilder.add_parallel(true);
        lpBuilder.add_inputIndexes(inputIndexesOffset);
        lpBuilder.add_outputIndexes(outputIndexOffset);
        lpBuilder.add_loopNumber(totalSize);
        lpBuilder.add_tensorNumber(number);
        auto lpOffset = lpBuilder.Finish();

        OpBuilder opBuilder(builder);
        opBuilder.add_main(lpOffset.Union());
        opBuilder.add_main_type(OpParameter_LoopParam);
        opBuilder.add_type(OpType_While);
        if (nullptr != op->name()) {
            opBuilder.add_name(nameOffset);
        }
        builder.Finish(opBuilder.Finish());
        std::vector<Tensor*> inputLoops{input0, input1, i0OffsetTensor.get(), i1OffsetTensor.get()};
        if (nullptr != bias) {
            inputLoops.emplace_back(bias);
        }
        auto cmd = GeometryComputerUtils::makeCommand(builder, inputLoops, {output});
        res.command.emplace_back(std::move(cmd));
        return true;
    }