community-content/vertex_model_garden/model_oss/peft/handler.py

"""Custom handler for huggingface/peft models.""" # pylint: disable=g-importing-member # pylint: disable=logging-fstring-interpolation import os import time from typing import Any, List, Tuple from absl import logging from awq import AutoAWQForCausalLM from diffusers import DPMSolverMultistepScheduler from diffusers import StableDiffusionPipeline from peft import PeftModel from PIL import Image import psutil import torch import transformers from transformers import AutoModelForCausalLM from transformers import AutoModelForSequenceClassification from transformers import AutoTokenizer from transformers import BitsAndBytesConfig from ts.torch_handler.base_handler import BaseHandler from util import constants from util import fileutils from util import image_format_converter if os.path.exists(constants.SHARED_MEM_DIR): logging.info( "SharedMemorySizeMb: %s", psutil.disk_usage(constants.SHARED_MEM_DIR).free / 1e6, ) # Tasks TEXT_TO_IMAGE_LORA = "text-to-image-lora" SEQUENCE_CLASSIFICATION_LORA = "sequence-classification-lora" CAUSAL_LANGUAGE_MODELING_LORA = "causal-language-modeling-lora" INSTRUCT_LORA = "instruct-lora" # Inference parameters. _NUM_INFERENCE_STEPS = 25 _MAX_LENGTH_DEFAULT = 200 _MAX_TOKENS_DEFAULT = None _TEMPERATURE_DEFAULT = 1.0 _TOP_P_DEFAULT = 1.0 _TOP_K_DEFAULT = 10 logging.set_verbosity(os.environ.get("LOG_LEVEL", logging.INFO)) class PeftHandler(BaseHandler): """Custom handler for Peft models.""" def initialize(self, context: Any): """Initializes the handler.""" logging.info("Start to initialize the PEFT handler.") properties = context.system_properties self.map_location = ( "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu" ) self.device = torch.device( self.map_location + ":" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else self.map_location ) self.manifest = context.manifest self.precision_mode = os.environ.get( "PRECISION_LOADING_MODE", constants.PRECISION_MODE_16 ) self.task = os.environ.get("TASK", CAUSAL_LANGUAGE_MODELING_LORA) trust_remote_code = os.environ.get("TRUST_REMOTE_CODE", None) if trust_remote_code == "false": self.trust_remote_code = False else: self.trust_remote_code = True # If present, the path of the model in the container. aip_storage_dir = os.environ.get("AIP_STORAGE_DIR", None) # If present, the URI of the model in a google owned GCS bucket. aip_storage_uri = os.environ.get("AIP_STORAGE_URI", None) model_id = os.environ.get("MODEL_ID", None) base_model_id = os.environ.get("BASE_MODEL_ID", None) self.model_id = None if aip_storage_dir: self.model_id = aip_storage_dir logging.info(f"Loaded base model from AIP_STORAGE_DIR: {self.model_id}.") elif aip_storage_uri: self.model_id = aip_storage_uri logging.info(f"Loaded base model from AIP_STORAGE_URI: {self.model_id}.") elif model_id: self.model_id = model_id logging.info(f"Loaded base model from MODEL_ID: {self.model_id}.") elif base_model_id: # Note: BASE_MODEL_ID has been unified with MODEL_ID. # MODEL_ID should be used whenever possible. 
      self.model_id = base_model_id
      logging.info(f"Loaded base model from BASE_MODEL_ID: {self.model_id}.")
    self.quantization = os.environ.get("QUANTIZATION", None)

    if not self.model_id:
      raise ValueError("Base model id must be set.")
    if fileutils.is_gcs_path(self.model_id):
      fileutils.download_gcs_dir_to_local(
          self.model_id,
          constants.LOCAL_BASE_MODEL_DIR,
          skip_hf_model_bin=True,
      )
      self.model_id = constants.LOCAL_BASE_MODEL_DIR

    self.finetuned_lora_model_path = os.environ.get(
        "FINETUNED_LORA_MODEL_PATH", ""
    )
    if fileutils.is_gcs_path(self.finetuned_lora_model_path):
      fileutils.download_gcs_dir_to_local(
          self.finetuned_lora_model_path, constants.LOCAL_MODEL_DIR
      )
      self.finetuned_lora_model_path = constants.LOCAL_MODEL_DIR

    logging.info(
        f"Using task:{self.task}, base model:{self.model_id}, lora model:"
        f" {self.finetuned_lora_model_path}, precision {self.precision_mode}."
    )

    self.pipeline = None
    self.model = None
    self.tokenizer = None

    start_time = time.perf_counter()
    logging.info("Started PEFT handler initialization at: %s", start_time)
    if self.task == TEXT_TO_IMAGE_LORA:
      pipeline = StableDiffusionPipeline.from_pretrained(
          self.model_id, torch_dtype=torch.float16
      )
      logging.debug("Initialized the base model for text to image.")
      pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
          pipeline.scheduler.config
      )
      logging.debug("Initialized the scheduler for text to image.")
      # This is to reduce GPU memory requirements.
      pipeline.enable_xformers_memory_efficient_attention()
      pipeline = pipeline.to(self.map_location)
      # Reduces memory footprint.
      pipeline.enable_attention_slicing()
      if self.finetuned_lora_model_path:
        pipeline.load_lora_weights(self.finetuned_lora_model_path)
        logging.debug("Initialized the LoRA model for text to image.")
      self.pipeline = pipeline
      logging.info("Initialized the text to image pipelines.")
    elif self.task == SEQUENCE_CLASSIFICATION_LORA:
      tokenizer = AutoTokenizer.from_pretrained(self.model_id)
      logging.debug("Initialized the tokenizer for sequence classification.")
      model = AutoModelForSequenceClassification.from_pretrained(
          self.model_id, torch_dtype=torch.float16
      )
      logging.debug("Initialized the base model for sequence classification.")
      if self.finetuned_lora_model_path:
        model = PeftModel.from_pretrained(
            model, self.finetuned_lora_model_path
        )
        logging.debug("Initialized the LoRA model for sequence classification.")
      model.to(self.map_location)
      self.model = model
      self.tokenizer = tokenizer
    elif (
        self.task == CAUSAL_LANGUAGE_MODELING_LORA
        or self.task == INSTRUCT_LORA
    ):
      tokenizer = AutoTokenizer.from_pretrained(
          self.model_id,
          trust_remote_code=self.trust_remote_code,
      )
      logging.debug("Initialized the tokenizer.")
      if self.task == CAUSAL_LANGUAGE_MODELING_LORA:
        if self.quantization == constants.AWQ:
          model = AutoAWQForCausalLM.from_quantized(
              self.model_id,
              trust_remote_code=self.trust_remote_code,
          )
        elif self.quantization == constants.GPTQ or not self.quantization:
          if self.precision_mode == constants.PRECISION_MODE_32:
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                torch_dtype=torch.float32,
                device_map="auto",
                trust_remote_code=self.trust_remote_code,
            )
          elif self.precision_mode == constants.PRECISION_MODE_16B:
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                torch_dtype=torch.bfloat16,
                device_map="auto",
                trust_remote_code=self.trust_remote_code,
            )
          elif self.precision_mode == constants.PRECISION_MODE_16:
            model = AutoModelForCausalLM.from_pretrained(
                self.model_id,
                return_dict=True,
                torch_dtype=torch.float16,
device_map="auto", trust_remote_code=self.trust_remote_code, ) elif self.precision_mode == constants.PRECISION_MODE_8: quantization_config = BitsAndBytesConfig( load_in_8bit=True, int8_threshold=0 ) model = AutoModelForCausalLM.from_pretrained( self.model_id, return_dict=True, torch_dtype=torch.float16, device_map="auto", quantization_config=quantization_config, trust_remote_code=self.trust_remote_code, ) else: quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( self.model_id, return_dict=True, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config, trust_remote_code=self.trust_remote_code, ) else: raise ValueError(f"Invalid QUANTIZATION value: {self.quantization}") else: try: model = AutoModelForCausalLM.from_pretrained( self.model_id, torch_dtype=torch.bfloat16, trust_remote_code=self.trust_remote_code, device_map="auto", ) except: # pylint: disable=bare-except model = AutoModelForCausalLM.from_pretrained( self.model_id, torch_dtype=torch.bfloat16, trust_remote_code=self.trust_remote_code, device_map="auto", ) logging.debug("Initialized the base model.") if self.finetuned_lora_model_path: model = PeftModel.from_pretrained(model, self.finetuned_lora_model_path) logging.debug("Initialized the LoRA model.") pipeline = transformers.pipeline( "text-generation", model=model, tokenizer=tokenizer, ) self.tokenizer = tokenizer self.pipeline = pipeline else: raise ValueError(f"Invalid TASK: {self.task}") self.initialized = True end_time = time.perf_counter() logging.info("The PEFT handler was initialize at: %s", end_time) logging.info("Handler initiation took %s seconds", end_time - start_time) def preprocess(self, data: Any) -> Any: """Preprocesses input data.""" # Assumes that the parameters are same in one request. We parse the # parameters from the first instance for all instances in one request. # For generation length: `max_length` defines the maximum length of the # sequence to be generated, including both input and output tokens. # `max_length` is overridden by `max_new_tokens` if also set. # `max_new_tokens` defines the maximum number of new tokens to generate, # ignoring the current number of tokens. # Reference: # https://github.com/huggingface/transformers/blob/574a5384557b1aaf98ddb13ea9eb0a0ee8ff2cb2/src/transformers/generation/configuration_utils.py#L69-L73 max_length = _MAX_LENGTH_DEFAULT max_tokens = _MAX_TOKENS_DEFAULT temperature = _TEMPERATURE_DEFAULT top_p = _TOP_P_DEFAULT top_k = _TOP_K_DEFAULT prompts = [item["prompt"] for item in data] if "max_length" in data[0]: max_length = data[0]["max_length"] if "max_tokens" in data[0]: max_tokens = data[0]["max_tokens"] if "temperature" in data[0]: temperature = data[0]["temperature"] if "top_p" in data[0]: top_p = data[0]["top_p"] if "top_k" in data[0]: top_k = data[0]["top_k"] return prompts, max_length, max_tokens, temperature, top_p, top_k def inference( self, data: Any, *args, **kwargs ) -> Tuple[List[str], List[Image.Image]]: """Runs the inference.""" prompts, max_length, max_tokens, temperature, top_p, top_k = data logging.debug( f"Inference prompts={prompts}, max_length={max_length}," f" max_tokens={max_tokens}, temperature={temperature}, top_p={top_p}," f" top_k={top_k}." 
    )
    if self.task == TEXT_TO_IMAGE_LORA:
      predicted_results = self.pipeline(
          prompt=prompts, num_inference_steps=_NUM_INFERENCE_STEPS
      ).images
    elif self.task == SEQUENCE_CLASSIFICATION_LORA:
      encoded_input = self.tokenizer(
          prompts, return_tensors="pt", padding=True
      )
      encoded_input.to(self.map_location)
      with torch.no_grad():
        outputs = self.model(**encoded_input)
        predictions = outputs.logits.argmax(dim=-1)
      predicted_results = predictions.tolist()
    elif (
        self.task == CAUSAL_LANGUAGE_MODELING_LORA
        or self.task == INSTRUCT_LORA
    ):
      predicted_results = self.pipeline(
          prompts,
          max_length=max_length,
          max_new_tokens=max_tokens,
          do_sample=True,
          temperature=temperature,
          top_p=top_p,
          top_k=top_k,
          num_return_sequences=1,
          eos_token_id=self.tokenizer.eos_token_id,
          return_full_text=False,
      )
    else:
      raise ValueError(f"Invalid TASK: {self.task}")
    return prompts, predicted_results

  def postprocess(self, data: Any) -> List[str]:
    """Postprocesses output data."""
    prompts, predicted_results = data
    if self.task == TEXT_TO_IMAGE_LORA:
      # Converts the images to base64 strings.
      outputs = [
          image_format_converter.image_to_base64(image)
          for image in predicted_results
      ]
    elif self.task == SEQUENCE_CLASSIFICATION_LORA:
      outputs = predicted_results
    else:
      outputs = []
      for prompt, predicted_result in zip(prompts, predicted_results):
        formatted_output = self._format_text_generation_output(
            prompt=prompt, output=predicted_result[0]["generated_text"]
        )
        outputs.append(formatted_output)
    return outputs

  def _format_text_generation_output(self, prompt: str, output: str) -> str:
    """Formats text generation output."""
    output = output.strip("\n")
    return f"Prompt:\n{prompt.strip()}\nOutput:\n{output}"


# pylint: enable=logging-fstring-interpolation
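

# ---------------------------------------------------------------------------
# Illustrative usage sketch (an assumption added for documentation, not part
# of the original handler): TorchServe drives the real flow as
# initialize -> preprocess -> inference -> postprocess, and the example
# request keys below ("prompt", "max_length", "max_tokens", "temperature",
# "top_p", "top_k") mirror what `preprocess` reads from each instance. The
# block only exercises request parsing and output formatting, so it runs
# without loading any model weights; the prompt and output strings are
# hypothetical example data.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
  _handler = PeftHandler()
  _example_instances = [{
      "prompt": "What does a LoRA adapter add to the base model?",
      "max_tokens": 64,
      "temperature": 0.7,
      "top_p": 0.9,
      "top_k": 10,
  }]
  # preprocess returns (prompts, max_length, max_tokens, temperature, top_p,
  # top_k); module-level defaults fill in any field missing from the first
  # instance.
  print(_handler.preprocess(_example_instances))
  # For the text-generation tasks, postprocess wraps each generation with its
  # prompt via _format_text_generation_output, as shown here.
  print(
      _handler._format_text_generation_output(  # pylint: disable=protected-access
          prompt=_example_instances[0]["prompt"],
          output="Low-rank weight updates merged on top of the base weights.",
      )
  )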