optimum/commands/onnxruntime/optimize.py (72 lines of code) (raw):

# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization with ONNX Runtime command-line interface class."""

from pathlib import Path
from typing import TYPE_CHECKING

from optimum.commands.base import BaseOptimumCLICommand


if TYPE_CHECKING:
    from argparse import ArgumentParser


def parse_args_onnxruntime_optimize(parser: "ArgumentParser"):
    """Register the arguments of the `onnxruntime optimize` command on `parser`.

    Adds two required arguments (`--onnx_model` and `-o/--output`). It also adds a required,
    mutually exclusive group in which exactly one of `-O1`, `-O2`, `-O3`, `-O4` or `-c/--config`
    must be given.

    Args:
        parser: The argument parser (or sub-parser) to populate in place.
    """
    required_group = parser.add_argument_group("Required arguments")
    required_group.add_argument(
        "--onnx_model",
        type=Path,
        required=True,
        help="Path to the repository where the ONNX models to optimize are located.",
    )
    required_group.add_argument(
        "-o",
        "--output",
        type=Path,
        required=True,
        help="Path to the directory where to store generated ONNX model.",
    )

    # Exactly one optimization source: a predefined level or an explicit `ORTConfig` file.
    level_group = parser.add_mutually_exclusive_group(required=True)
    level_group.add_argument(
        "-O1",
        action="store_true",
        help="Basic general optimizations (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
    )
    level_group.add_argument(
        "-O2",
        action="store_true",
        help="Basic and extended general optimizations, transformers-specific fusions (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
    )
    level_group.add_argument(
        "-O3",
        action="store_true",
        help="Same as O2 with Gelu approximation (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
    )
    level_group.add_argument(
        "-O4",
        action="store_true",
        help="Same as O3 with mixed precision (see: https://huggingface.co/docs/optimum/onnxruntime/usage_guides/optimization for more details).",
    )
    level_group.add_argument(
        "-c",
        "--config",
        type=Path,
        help="`ORTConfig` file to use to optimize the model.",
    )


class ONNXRuntimeOptimizeCommand(BaseOptimumCLICommand):
    """CLI command that optimizes every ONNX model in a directory with `ORTOptimizer`."""

    @staticmethod
    def parse_args(parser: "ArgumentParser"):
        """Register this command's arguments on `parser`."""
        return parse_args_onnxruntime_optimize(parser)

    def run(self):
        """Optimize all `*.onnx` files found in `--onnx_model` and save them to `--output`.

        The optimization configuration comes from `--config` when given. Otherwise it comes from
        the selected `-O1`..`-O4` level via `AutoOptimizationConfig`.

        Raises:
            ValueError: If the output directory is the same as the input directory, if the input
                directory contains no `*.onnx` file, or if no optimization level/config was given.
        """
        # Imported lazily so that building the CLI does not require ONNX Runtime to be importable.
        from ...onnxruntime.configuration import AutoOptimizationConfig, ORTConfig
        from ...onnxruntime.optimization import ORTOptimizer

        onnx_model_dir = Path(self.args.onnx_model)
        save_dir = Path(self.args.output)

        # Compare resolved paths: a plain `==` treats e.g. `model` and `out/../model` (or a symlink
        # to the input directory) as different, which would let the input models be overwritten.
        if save_dir.resolve() == onnx_model_dir.resolve():
            raise ValueError("The output directory must be different than the directory hosting the ONNX model.")

        # `Path.glob` yields entries in arbitrary order; sort for a deterministic processing order.
        file_names = sorted(model.name for model in onnx_model_dir.glob("*.onnx"))
        if not file_names:
            # Without this check an empty list is passed on and the command silently optimizes nothing.
            raise ValueError(f"No ONNX model (`*.onnx` file) was found in {onnx_model_dir}.")

        optimizer = ORTOptimizer.from_pretrained(onnx_model_dir, file_names)

        if self.args.config:
            optimization_config = ORTConfig.from_pretrained(self.args.config).optimization
        elif self.args.O1:
            optimization_config = AutoOptimizationConfig.O1()
        elif self.args.O2:
            optimization_config = AutoOptimizationConfig.O2()
        elif self.args.O3:
            optimization_config = AutoOptimizationConfig.O3()
        elif self.args.O4:
            optimization_config = AutoOptimizationConfig.O4()
        else:
            # Unreachable through the CLI (the group is required), but kept for programmatic callers.
            raise ValueError("Either -O1, -O2, -O3, -O4 or -c must be specified.")

        optimizer.optimize(save_dir=save_dir, optimization_config=optimization_config)