optimum_benchmark/backends/pytorch/backend.py (285 lines of code) (raw):

from collections import OrderedDict
from collections.abc import Mapping
from tempfile import TemporaryDirectory
from typing import Any, Callable, Dict, List

import torch
from accelerate import Accelerator

# from accelerate.utils import compile_regions
from datasets import Dataset
from transformers import Trainer, TrainerCallback, TrainerState, TrainingArguments
from transformers.quantizers import AutoQuantizationConfig

from ...import_utils import (
    is_deepspeed_available,
    is_gptqmodel_available,
    is_torch_distributed_available,
    is_zentorch_available,
)
from ..base import Backend
from ..peft_utils import apply_peft
from ..transformers_utils import fast_weights_init
from .config import PyTorchConfig

if is_deepspeed_available():
    import deepspeed  # type: ignore

if is_torch_distributed_available():
    import torch.distributed  # type: ignore

if is_zentorch_available():
    import zentorch  # type: ignore # noqa: F401

if is_gptqmodel_available():
    import enum

    if not hasattr(enum, "EnumType") and hasattr(enum, "EnumMeta"):
        # This is a workaround for a bug in gptqmodel where it tries to access EnumType
        # from the enum module, but it is not available in Python 3.10 and below.
        enum.EnumType = enum.EnumMeta


def _lookup_field(container: Any, key: str) -> Any:
    """Return ``container[key]`` for mappings, ``container.key`` otherwise, or ``None`` if absent."""
    if isinstance(container, Mapping):
        return container.get(key)
    return getattr(container, key, None)


class PyTorchBackend(Backend[PyTorchConfig]):
    """Backend that loads and runs models through PyTorch.

    Depending on ``config.library`` it loads a Transformers model, a Diffusers
    pipeline or a Timm model into ``self.pretrained_model``. It exposes the
    inference entry points ``forward``, ``prefill``, ``generate`` and ``call``,
    plus ``train``, which wraps the model in a Transformers ``Trainer``.

    Several attributes used below (``logger``, ``automodel_loader``,
    ``pretrained_config``, ``no_weights_model_path``,
    ``create_no_weights_model_fast/slow``) are not defined in this file;
    presumably they come from the ``Backend`` base class.
    """

    NAME = "pytorch"

    def __init__(self, config: PyTorchConfig):
        """Apply the global PyTorch settings requested by ``config``.

        Configures thread counts, TF32 and autocast on the ``torch`` module.

        Raises:
            ValueError: if an autocast dtype is requested for a device other
                than ``"cpu"`` or ``"cuda"``.
        """
        super().__init__(config)

        # Threads
        # torch.set_num_interop_threads controls inter-op parallelism and
        # torch.set_num_threads controls intra-op parallelism; each config
        # field must go to the setter of the same kind.
        if self.config.inter_op_num_threads is not None:
            self.logger.info(f"\t+ Setting pytorch inter_op_num_threads({self.config.inter_op_num_threads}))")
            torch.set_num_interop_threads(self.config.inter_op_num_threads)

        if self.config.intra_op_num_threads is not None:
            self.logger.info(f"\t+ Setting pytorch intra_op_num_threads({self.config.intra_op_num_threads}))")
            torch.set_num_threads(self.config.intra_op_num_threads)

        # TF32
        if self.config.allow_tf32:
            self.logger.info("\t+ Enabling TF32")
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True

        # Autocast
        if self.config.autocast_enabled:
            self.logger.info("\t+ Enabling automatic mixed precision")
            torch.set_autocast_enabled(True)

            if self.config.autocast_dtype is not None:
                if self.config.device == "cpu":
                    self.logger.info(f"\t+ Setting autocast cpu dtype to {self.config.autocast_dtype}")
                    torch.set_autocast_cpu_dtype(getattr(torch, self.config.autocast_dtype))
                elif self.config.device == "cuda":
                    self.logger.info(f"\t+ Setting autocast gpu dtype to {self.config.autocast_dtype}")
                    torch.set_autocast_gpu_dtype(getattr(torch, self.config.autocast_dtype))
                else:
                    raise ValueError(f"Device {self.config.device} not supported for autocast")

    def load(self) -> None:
        """Load the model for ``config.library`` using a backend temporary directory.

        The temporary directory is stored on ``self.tmpdir`` for the duration
        of loading and is cleaned up afterwards, including when loading fails.

        Raises:
            ValueError: if ``config.library`` is not ``"transformers"``,
                ``"diffusers"`` or ``"timm"``.
        """
        self.logger.info("\t+ Creating backend temporary directory")
        self.tmpdir = TemporaryDirectory()

        try:
            if self.config.library == "transformers":
                self.load_transformers_model()
            elif self.config.library == "diffusers":
                self.load_diffusers_model()
            elif self.config.library == "timm":
                self.load_timm_model()
            else:
                raise ValueError(f"Library {self.config.library} not supported for PyTorch backend")
        finally:
            # Run cleanup even when a loader raises so the directory does not linger.
            self.logger.info("\t+ Cleaning up backend temporary directory")
            self.tmpdir.cleanup()

    def load_transformers_model_from_pretrained(self) -> None:
        """Load ``config.model`` via ``automodel_loader.from_pretrained`` and place it on the device.

        The model is moved to ``config.device`` only when no ``device_map`` is
        configured and the device is not ``"cpu"``.
        """
        self.logger.info("\t+ Loading Transformers model")
        self.pretrained_model = self.automodel_loader.from_pretrained(
            pretrained_model_name_or_path=self.config.model,
            **self.config.model_kwargs,
            **self.automodel_kwargs,
        )
        if self.config.device_map is None and self.config.device != "cpu":
            self.logger.info(f"\t+ Moving Transformers model to device: {self.config.device}")
            self.pretrained_model = self.pretrained_model.to(self.config.device)

    def load_transformers_model_with_no_weights(self) -> None:
        """Load the model from ``no_weights_model_path`` instead of ``config.model``.

        ``config.model`` is temporarily replaced by the no-weights path and is
        restored afterwards, including when loading fails. Loading runs inside
        ``fast_weights_init()``.
        """
        # NOTE(review): fast_weights_init is defined elsewhere; presumably it
        # skips expensive weight initialization. Confirm in transformers_utils.
        with fast_weights_init():
            original_model, self.config.model = self.config.model, self.no_weights_model_path.as_posix()
            try:
                self.load_transformers_model_from_pretrained()
            finally:
                self.config.model = original_model

    def load_transformers_model(self):
        """Load a Transformers model and apply the configured post-processing steps.

        In order: build the quantization config (if quantized), load the model
        (random or pretrained weights), set the KV-cache implementation,
        convert to BetterTransformer, switch to eval mode, apply PEFT, wrap
        with DeepSpeed inference, and apply ``torch.compile``. Each step runs
        only if enabled in ``config``.

        Raises:
            ValueError: if DeepSpeed inference is combined with quantization,
                or if ``config.torch_compile_target`` is neither ``"model"``
                nor ``"forward"``.
        """
        if self.config.deepspeed_inference and self.is_quantized:
            raise ValueError("Deepspeed-Inference is not compatible with Transformers quantization")

        # Quantization
        if self.is_quantized:
            self.logger.info("\t+ Processing AutoQuantization config")
            # User-provided entries override those found on the pretrained config.
            self.quantization_config = AutoQuantizationConfig.from_dict(
                dict(
                    getattr(self.pretrained_config, "quantization_config", {}),
                    **self.config.quantization_config,
                )
            )

        # Model loading
        if self.config.no_weights:
            self.logger.info("\t+ Creating no weights model")
            if self.config.tp_plan is not None:
                self.create_no_weights_model_slow()
            else:
                self.create_no_weights_model_fast()
            self.logger.info("\t+ Loading model with random weights")
            self.load_transformers_model_with_no_weights()
        else:
            self.logger.info("\t+ Loading model with pretrained weights")
            self.load_transformers_model_from_pretrained()

        # KV-Cache
        if self.config.cache_implementation is not None:
            self.logger.info(f"\t+ Setting cache implementation to {self.config.cache_implementation}")
            self.pretrained_model.generation_config.cache_implementation = self.config.cache_implementation

        # BetterTransformer
        if self.config.to_bettertransformer:
            self.logger.info("\t+ To BetterTransformer")
            self.pretrained_model.to_bettertransformer()

        # Eval mode
        if self.config.eval_mode:
            self.logger.info("\t+ Enabling eval mode")
            self.pretrained_model.eval()

        # PEFT
        if self.config.peft_type is not None:
            self.logger.info("\t+ Applying PEFT")
            self.pretrained_model = apply_peft(self.pretrained_model, self.config.peft_type, self.config.peft_config)

        # DeepSpeed
        if self.config.deepspeed_inference:
            self.logger.info("\t+ Initializing DeepSpeed Inference Engine")
            self.pretrained_model = deepspeed.init_inference(
                model=self.pretrained_model, config=self.config.deepspeed_inference_config
            )

        # Torch compile
        if self.config.torch_compile:
            if self.config.torch_compile_target == "model":
                self.logger.info("\t+ Using torch.compile on model")
                self.pretrained_model = torch.compile(self.pretrained_model, **self.config.torch_compile_config)
            # elif self.config.torch_compile_target == "regions":
            #     self.logger.info("\t+ Using torch.compile on regions")
            #     self.pretrained_model = compile_regions(self.pretrained_model, **self.config.torch_compile_config)
            elif self.config.torch_compile_target == "forward":
                self.logger.info("\t+ Using torch.compile on forward")
                self.pretrained_model.forward = torch.compile(
                    self.pretrained_model.forward, **self.config.torch_compile_config
                )
            else:
                raise ValueError(f"Target {self.config.torch_compile_target} not supported")

    def load_diffusers_pipeline_from_pretrained(self) -> None:
        """Load the Diffusers pipeline for ``config.model`` and place it on the device.

        The pipeline is moved to ``config.device`` only when no ``device_map``
        is configured and the device is not ``"cpu"``.
        """
        # automodel_kwargs already carries device_map when it is configured;
        # passing it a second time as an explicit keyword would raise
        # "got multiple values for keyword argument 'device_map'".
        loader_kwargs = self.automodel_kwargs
        loader_kwargs.setdefault("device_map", self.config.device_map)

        self.pretrained_model = self.automodel_loader.from_pretrained(
            self.config.model,
            # pretrained_model_name_or_path=self.config.model,
            # pretrained_model_or_path=self.config.model,
            **self.config.model_kwargs,
            **loader_kwargs,
        )
        if self.config.device_map is None and self.config.device != "cpu":
            self.logger.info(f"\t+ Moving Diffusion Pipeline to device: {self.config.device}")
            self.pretrained_model = self.pretrained_model.to(self.config.device)

    def load_diffusers_model(self):
        """Load a Diffusers pipeline and optionally compile its ``unet`` and ``vae.decode``.

        Raises:
            ValueError: if ``config.no_weights`` is set (not supported here).
        """
        self.logger.info("\t+ Loading Diffusion Pipeline")
        self.logger.info(f"\t+ Using Diffusers Pipeline {self.automodel_loader.__name__}")

        # Model loading
        if self.config.no_weights:
            raise ValueError("No weights model not supported for Diffusers")
        else:
            self.load_diffusers_pipeline_from_pretrained()

        # Torch compile
        if self.config.torch_compile:
            self.logger.info("\t+ Using torch.compile on unet and vae")
            self.pretrained_model.unet = torch.compile(self.pretrained_model.unet, **self.config.torch_compile_config)
            self.pretrained_model.vae.decode = torch.compile(
                self.pretrained_model.vae.decode, **self.config.torch_compile_config
            )

    def load_timm_model_form_pretrained(self) -> None:
        """Create the Timm model named ``config.model`` and move it off CPU if needed."""
        self.pretrained_model = self.automodel_loader(model_name=self.config.model)
        if self.config.device != "cpu":
            self.logger.info(f"\t+ Moving Timm model to device: {self.config.device}")
            self.pretrained_model = self.pretrained_model.to(self.config.device)

    def load_timm_model(self):
        """Load a Timm model and optionally apply ``torch.compile``.

        Raises:
            ValueError: if ``config.no_weights`` is set, or if
                ``config.torch_compile_target`` is neither ``"forward"`` nor
                ``"model"``.
        """
        self.logger.info("\t+ Loading Timm model")
        self.logger.info(f"\t+ Using Timm's {self.automodel_loader.__name__}")

        # Model loading
        if self.config.no_weights:
            raise ValueError("No weights model not supported for Timm")
        else:
            self.load_timm_model_form_pretrained()

        # Torch compile
        if self.config.torch_compile:
            if self.config.torch_compile_target == "forward":
                self.logger.info("\t+ Using torch.compile on forward")
                self.pretrained_model.forward = torch.compile(
                    self.pretrained_model.forward, **self.config.torch_compile_config
                )
            elif self.config.torch_compile_target == "model":
                self.logger.info("\t+ Using torch.compile on model")
                self.pretrained_model = torch.compile(self.pretrained_model, **self.config.torch_compile_config)
            else:
                raise ValueError(f"Target {self.config.torch_compile_target} not supported")

    @property
    def is_quantized(self) -> bool:
        """Whether a quantization scheme is configured or the pretrained config declares a ``quant_method``."""
        return self.config.quantization_scheme is not None or (
            hasattr(self.pretrained_config, "quantization_config")
            and self.pretrained_config.quantization_config.get("quant_method") is not None
        )

    @property
    def is_gptq_quantized(self) -> bool:
        """Whether the configured scheme or the pretrained ``quant_method`` is ``"gptq"``."""
        return self.config.quantization_scheme == "gptq" or (
            hasattr(self.pretrained_config, "quantization_config")
            and self.pretrained_config.quantization_config.get("quant_method") == "gptq"
        )

    @property
    def is_bnb_quantized(self) -> bool:
        """Whether the configured scheme or the pretrained ``quant_method`` is ``"bnb"``."""
        # NOTE(review): transformers appears to record bitsandbytes checkpoints
        # with quant_method "bitsandbytes" rather than "bnb"; confirm whether
        # the second comparison can ever match.
        return self.config.quantization_scheme == "bnb" or (
            hasattr(self.pretrained_config, "quantization_config")
            and self.pretrained_config.quantization_config.get("quant_method") == "bnb"
        )

    @property
    def is_exllamav2(self) -> bool:
        """Whether the model is GPTQ-quantized with an ``exllama_config`` of version 2.

        The ``exllama_config`` is looked up both on the pretrained config's
        ``quantization_config`` and on ``config.quantization_config``.
        """
        if not (self.is_quantized and self.is_gptq_quantized):
            return False

        # The sibling properties read the pretrained quantization_config as a
        # dict (``.get("quant_method")``), so it must be indexed by key here;
        # ``hasattr`` on a dict never sees its keys. ``_lookup_field`` also
        # accepts attribute-style containers.
        candidates = (
            getattr(self.pretrained_config, "quantization_config", None),
            self.config.quantization_config,
        )
        return any(
            _lookup_field(_lookup_field(candidate, "exllama_config"), "version") == 2 for candidate in candidates
        )

    @property
    def automodel_kwargs(self) -> Dict[str, Any]:
        """Build the keyword arguments forwarded to the model loader.

        Only options that are set in ``config`` are included. ``torch_dtype``
        is resolved to a ``torch`` attribute when one with that name exists,
        otherwise the raw string is passed through. A new dict is built on
        every access.
        """
        kwargs = {}

        if self.config.torch_dtype is not None:
            if hasattr(torch, self.config.torch_dtype):
                kwargs["torch_dtype"] = getattr(torch, self.config.torch_dtype)
            else:
                kwargs["torch_dtype"] = self.config.torch_dtype

        if self.is_quantized:
            kwargs["quantization_config"] = self.quantization_config

        if self.config.attn_implementation is not None:
            kwargs["attn_implementation"] = self.config.attn_implementation

        if self.config.low_cpu_mem_usage is not None:
            kwargs["low_cpu_mem_usage"] = self.config.low_cpu_mem_usage

        if self.config.device_map is not None:
            kwargs["device_map"] = self.config.device_map

        if self.config.tp_plan is not None:
            kwargs["tp_plan"] = self.config.tp_plan

        return kwargs

    @property
    def split_between_processes(self) -> bool:
        """Whether inputs should be split across ``torch.distributed`` processes.

        True when torch.distributed is available and initialized and DeepSpeed
        inference is not enabled.
        """
        return (
            is_torch_distributed_available()
            and torch.distributed.is_initialized()
            and not self.config.deepspeed_inference
        )

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Split inputs across processes if needed and move tensors to ``config.device``.

        Tensor values are replaced in the dict being processed. For the
        ``"timm"`` library the result is a new dict holding only
        ``inputs["pixel_values"]`` under the key ``"x"``.
        """
        if self.split_between_processes:
            with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs:
                inputs = process_inputs

        # Reassigning existing keys while iterating is safe: the dict size does not change.
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                inputs[key] = value.to(self.config.device)

        if self.config.library == "timm":
            inputs = {"x": inputs["pixel_values"]}

        return inputs

    @torch.inference_mode()
    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Call the model's ``forward`` with ``inputs`` and ``kwargs`` under inference mode."""
        return self.pretrained_model.forward(**inputs, **kwargs)

    @torch.inference_mode()
    def prefill(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Run ``generate`` for exactly one new token under inference mode.

        Raises:
            AssertionError: if ``max_new_tokens`` and ``min_new_tokens`` in
                ``kwargs`` are not both equal to 1.
        """
        assert kwargs.get("max_new_tokens") == kwargs.get("min_new_tokens") == 1, (
            "For prefilling, max_new_tokens and min_new_tokens must be equal to 1"
        )
        return self.pretrained_model.generate(**inputs, **kwargs)

    @torch.inference_mode()
    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Call the model's ``generate`` with ``inputs`` and ``kwargs`` under inference mode."""
        return self.pretrained_model.generate(**inputs, **kwargs)

    @torch.inference_mode()
    def call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict:
        """Call the model object directly with ``inputs`` and ``kwargs`` under inference mode."""
        return self.pretrained_model(**inputs, **kwargs)

    def train(
        self,
        training_dataset: Dataset,
        training_arguments: Dict[str, Any],
        training_callbacks: List[TrainerCallback],
        training_data_collator: Callable[[List[Dict[str, Any]]], Dict[str, Any]],
    ) -> TrainerState:
        """Train ``self.pretrained_model`` with a Transformers ``Trainer``.

        Args:
            training_dataset: dataset passed to the trainer as ``train_dataset``.
            training_arguments: keyword arguments for ``TrainingArguments``.
                ``use_cpu`` is set from ``config.device``; the caller's dict is
                left unmodified.
            training_callbacks: callbacks passed to the trainer.
            training_data_collator: collator passed to the trainer as
                ``data_collator``.

        Returns:
            The trainer's ``state`` once training has finished.
        """
        self.logger.info(f"\t+ Wrapping training arguments with {TrainingArguments.__name__}")
        # Build a fresh mapping rather than writing "use_cpu" into the caller's dict.
        wrapped_arguments = TrainingArguments(**{**training_arguments, "use_cpu": self.config.device == "cpu"})

        self.logger.info(f"\t+ Wrapping model with {Trainer.__name__}")
        trainer = Trainer(
            args=wrapped_arguments,
            model=self.pretrained_model,
            callbacks=training_callbacks,
            train_dataset=training_dataset,
            data_collator=training_data_collator,
        )

        self.logger.info("\t+ Starting training")
        trainer.train()
        self.logger.info("\t+ Finished training")

        # The signature promises a TrainerState; the original fell through and returned None.
        return trainer.state