optimum/quanto/subpackage/commands/quantize.py
def run(self):
    model_name_or_path = self.args.model
    # Infer the model library (e.g. transformers) if it was not given explicitly
    library_name = self.args.library
    if library_name is None:
        library_name = TasksManager.infer_library_from_model(model_name_or_path)
    if library_name not in SUPPORTED_LIBRARIES:
        raise ValueError(
            f"{library_name} models are not supported by this CLI, but can be quantized using the python API directly."
        )
    # Infer the task (e.g. text-generation) if it was not given explicitly
    task = self.args.task
    if task is None:
        task = TasksManager.infer_task_from_model(model_name_or_path)
    # Resolve the requested torch dtype ("auto" is passed through unchanged)
    torch_dtype = self.args.torch_dtype
    if torch_dtype != "auto":
        torch_dtype = torch.float16 if self.args.torch_dtype == "fp16" else torch.bfloat16
    # Load the full-precision model for the resolved task on the target device
    model = TasksManager.get_model_from_task(
        task,
        model_name_or_path,
        revision=self.args.revision,
        trust_remote_code=self.args.trust_remote_code,
        framework="pt",
        torch_dtype=torch_dtype,
        device=torch.device(self.args.device),
        library_name=library_name,
        low_cpu_mem_usage=True,
    )
    # Map the CLI weights argument (e.g. "int8") to a quanto qtype name (e.g. "qint8")
    weights = f"q{self.args.weights}"
    # Quantize the model weights and serialize the quantized model to the output directory
    qmodel = QuantizedTransformersModel.quantize(model, weights=weights)
    qmodel.save_pretrained(self.args.output)
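
As the ValueError message above notes, models that this CLI does not handle can still be quantized through the Python API. The sketch below mirrors the same flow (load, quantize weights, save) directly in Python; it assumes a transformers causal-LM checkpoint, that QuantizedTransformersModel is importable from optimum.quanto, and that it accepts the same string weight spec ("qint8") that the CLI builds above. The model id and output path are placeholders.

    # Minimal Python-API sketch of the CLI flow above (not the CLI itself).
    import torch
    from transformers import AutoModelForCausalLM

    from optimum.quanto import QuantizedTransformersModel  # assumed import path

    model = AutoModelForCausalLM.from_pretrained(
        "facebook/opt-125m",        # placeholder checkpoint
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    # Same call the CLI makes: weight-only quantization to the requested qtype
    qmodel = QuantizedTransformersModel.quantize(model, weights="qint8")
    qmodel.save_pretrained("opt-125m-qint8")   # placeholder output directory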