mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp (48 lines of code) (raw):

//===- TosaToLinalgPass.cpp - Lowering Tosa to Linalg Dialect -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This transformation pass legalizes Tosa operations to the Linalg dialect. // //===----------------------------------------------------------------------===// #include "../PassDetail.h" #include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/SCF/SCF.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Tosa/Transforms/PassDetail.h" #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; namespace { struct TosaToLinalgNamed : public TosaToLinalgNamedBase<TosaToLinalgNamed> { public: void getDependentDialects(DialectRegistry &registry) const override { registry.insert<arith::ArithmeticDialect, linalg::LinalgDialect, math::MathDialect, StandardOpsDialect, tensor::TensorDialect, scf::SCFDialect>(); } void runOnFunction() override { RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect<linalg::LinalgDialect, StandardOpsDialect, tosa::TosaDialect, tensor::TensorDialect, scf::SCFDialect>(); // Not every TOSA op can be legalized to linalg. target.addIllegalOp<tosa::Conv2DOp>(); target.addIllegalOp<tosa::DepthwiseConv2DOp>(); target.addIllegalOp<tosa::MaxPool2dOp>(); target.addIllegalOp<tosa::AvgPool2dOp>(); target.addIllegalOp<tosa::MatMulOp>(); target.addIllegalOp<tosa::FullyConnectedOp>(); target.markUnknownOpDynamicallyLegal([](Operation *) { return true; }); FuncOp func = getFunction(); mlir::tosa::populateTosaToLinalgNamedConversionPatterns(&patterns); if (failed(applyFullConversion(func, target, std::move(patterns)))) signalPassFailure(); } }; } // namespace std::unique_ptr<Pass> mlir::tosa::createTosaToLinalgNamed() { return std::make_unique<TosaToLinalgNamed>(); }