optimum/quanto/subpackage/commands/quantize.py:

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Hugging Face models quantization command-line interface class."""

from typing import TYPE_CHECKING

import torch

from optimum.commands import BaseOptimumCLICommand
from optimum.exporters import TasksManager

from ...models import QuantizedTransformersModel


if TYPE_CHECKING:
    from argparse import ArgumentParser


SUPPORTED_LIBRARIES = ["transformers"]


def parse_quantize_args(parser: "ArgumentParser"):
    required_group = parser.add_argument_group("Required arguments")
    required_group.add_argument(
        "output",
        type=str,
        help="The path to save the quantized model.",
    )
    required_group.add_argument(
        "-m",
        "--model",
        type=str,
        required=True,
        help="Hugging Face Hub model id or path to a local model.",
    )
    required_group.add_argument(
        "--weights",
        type=str,
        default="int8",
        choices=["int2", "int4", "int8", "float8"],
        help="The data type to quantize the model weights to.",
    )

    optional_group = parser.add_argument_group("Optional arguments")
    optional_group.add_argument(
        "--revision",
        type=str,
        default=None,
        help="The Hugging Face model revision.",
    )
    optional_group.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="Trust remote code when loading the model.",
    )
    optional_group.add_argument(
        "--library",
        type=str,
        default=None,
        choices=SUPPORTED_LIBRARIES,
        help="The Hugging Face library to use to load the model.",
    )
    optional_group.add_argument(
        "--task",
        type=str,
        default=None,
        help="The model task (useful for models supporting multiple tasks).",
    )
    optional_group.add_argument(
        "--torch_dtype",
        type=str,
        default="auto",
        choices=["auto", "fp16", "bf16"],
        help="The torch dtype to use when loading the model weights.",
    )
    optional_group.add_argument(
        "--device",
        type=str,
        default="cpu",
        help="The device to use when loading the model.",
    )


class QuantizeCommand(BaseOptimumCLICommand):
    @staticmethod
    def parse_args(parser: "ArgumentParser"):
        return parse_quantize_args(parser)

    def run(self):
        model_name_or_path = self.args.model
        # Infer the library from the model files when it is not specified explicitly.
        library_name = self.args.library
        if library_name is None:
            library_name = TasksManager.infer_library_from_model(model_name_or_path)
        if library_name not in SUPPORTED_LIBRARIES:
            raise ValueError(
                f"{library_name} models are not supported by this CLI, but can be quantized using the python API directly."
            )
        # Infer the task from the model when it is not specified explicitly.
        task = self.args.task
        if task is None:
            task = TasksManager.infer_task_from_model(model_name_or_path)
        # Map the CLI dtype flag to a torch dtype ("auto" defers to the checkpoint).
        torch_dtype = self.args.torch_dtype
        if torch_dtype != "auto":
            torch_dtype = torch.float16 if self.args.torch_dtype == "fp16" else torch.bfloat16
        model = TasksManager.get_model_from_task(
            task,
            model_name_or_path,
            revision=self.args.revision,
            trust_remote_code=self.args.trust_remote_code,
            framework="pt",
            torch_dtype=torch_dtype,
            device=torch.device(self.args.device),
            library_name=library_name,
            low_cpu_mem_usage=True,
        )
        # Quantize the model weights (e.g. "int8" -> "qint8") and serialize the result.
        weights = f"q{self.args.weights}"
        qmodel = QuantizedTransformersModel.quantize(model, weights=weights)
        qmodel.save_pretrained(self.args.output)
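
The error message in run() points to the Python API as the fallback for unsupported libraries. Below is a minimal sketch of that equivalent flow, mirroring what the command does end to end: load a transformers model, quantize its weights, and save the result. The checkpoint id and output path are placeholder values, and the example assumes a standard transformers causal-LM checkpoint loaded with AutoModelForCausalLM rather than via TasksManager.

import torch
from transformers import AutoModelForCausalLM

from optimum.quanto.models import QuantizedTransformersModel

# Load the model in half precision, as `--torch_dtype fp16` would.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model id, mirrors --model
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)

# Quantize the weights to "qint8", as `--weights int8` would, then save.
qmodel = QuantizedTransformersModel.quantize(model, weights="qint8")
qmodel.save_pretrained("./opt-125m-quantized")  # placeholder path, mirrors the positional `output` argument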