community-content/vertex_model_garden/model_oss/transformers/handler.py

"""Custom handler for huggingface/transformers models.""" # pylint: disable=g-multiple-import # pylint: disable=g-importing-member import logging import os from typing import Any, List, Optional, Tuple from PIL import Image import torch from transformers import ( AutoProcessor, AutoTokenizer, Blip2ForConditionalGeneration, Blip2Processor, BlipForConditionalGeneration, BlipForQuestionAnswering, BlipProcessor, CLIPModel, ) from transformers import pipeline from ts.torch_handler.base_handler import BaseHandler from util import constants from util import fileutils from util import image_format_converter DEFAULT_MODEL_ID = "openai/clip-vit-base-patch32" SALESFORCE_BLIP = "Salesforce/blip" SALESFORCE_BLIP2 = "Salesforce/blip2" FLAN_T5 = "flan-t5" BART_LARGE_CNN = "facebook/bart-large-cnn" ZERO_CLASSIFICATION = "zero-shot-image-classification" FEATURE_EMBEDDING = "feature-embedding" ZERO_DETECTION = "zero-shot-object-detection" IMAGE_CAPTIONING = "image-to-text" VQA = "visual-question-answering" DQA = "document-question-answering" SUMMARIZATION = "summarization" SUMMARIZATION_TEMPLATE = ( "Summarize the following news article:\n{input}\nSummary:\n" ) class TransformersHandler(BaseHandler): """Custom handler for huggingface/transformers models.""" def initialize(self, context: Any): """Custom initialize.""" properties = context.system_properties self.map_location = ( "cuda" if torch.cuda.is_available() and properties.get("gpu_id") is not None else "cpu" ) self.device = torch.device( self.map_location + ":" + str(properties.get("gpu_id")) if torch.cuda.is_available() and properties.get("gpu_id") is not None else self.map_location ) self.manifest = context.manifest # The model id is can be either: # 1) a huggingface model card id, like "Salesforce/blip", or # 2) a GCS path to the model files, like "gs://foo/bar". # If it's a model card id, the model will be loaded from huggingface. self.model_id = ( DEFAULT_MODEL_ID if os.environ.get("MODEL_ID") is None else os.environ["MODEL_ID"] ) # Else it will be downloaded from GCS to local first. # Since the transformers from_pretrained API can't read from GCS. if self.model_id.startswith(constants.GCS_URI_PREFIX): gcs_path = self.model_id[len(constants.GCS_URI_PREFIX) :] local_model_dir = os.path.join(constants.LOCAL_MODEL_DIR, gcs_path) logging.info("Download %s to %s", self.model_id, local_model_dir) fileutils.download_gcs_dir_to_local(self.model_id, local_model_dir) self.model_id = local_model_dir self.task = ( ZERO_CLASSIFICATION if os.environ.get("TASK") is None else os.environ["TASK"] ) logging.info( "Handler initializing task:%s, model:%s", self.task, self.model_id ) if SALESFORCE_BLIP in self.model_id: # pipeline() hasn't been ready for Salesforce/blip models. 

class TransformersHandler(BaseHandler):
  """Custom handler for huggingface/transformers models."""

  def initialize(self, context: Any):
    """Custom initialize."""
    properties = context.system_properties
    self.map_location = (
        "cuda"
        if torch.cuda.is_available() and properties.get("gpu_id") is not None
        else "cpu"
    )
    self.device = torch.device(
        self.map_location + ":" + str(properties.get("gpu_id"))
        if torch.cuda.is_available() and properties.get("gpu_id") is not None
        else self.map_location
    )
    self.manifest = context.manifest

    # The model id can be either:
    # 1) a huggingface model card id, like "Salesforce/blip", or
    # 2) a GCS path to the model files, like "gs://foo/bar".
    # If it's a model card id, the model will be loaded from huggingface.
    self.model_id = (
        DEFAULT_MODEL_ID
        if os.environ.get("MODEL_ID") is None
        else os.environ["MODEL_ID"]
    )
    # Otherwise the model will be downloaded from GCS to a local directory
    # first, since the transformers from_pretrained API can't read from GCS.
    if self.model_id.startswith(constants.GCS_URI_PREFIX):
      gcs_path = self.model_id[len(constants.GCS_URI_PREFIX) :]
      local_model_dir = os.path.join(constants.LOCAL_MODEL_DIR, gcs_path)
      logging.info("Download %s to %s", self.model_id, local_model_dir)
      fileutils.download_gcs_dir_to_local(self.model_id, local_model_dir)
      self.model_id = local_model_dir

    self.task = (
        ZERO_CLASSIFICATION
        if os.environ.get("TASK") is None
        else os.environ["TASK"]
    )
    logging.info(
        "Handler initializing task:%s, model:%s", self.task, self.model_id
    )

    if SALESFORCE_BLIP in self.model_id:
      # pipeline() doesn't support Salesforce/blip models yet.
      self.salesforce_blip = True
      self._create_blip_model()
    else:
      self.salesforce_blip = False
      if self.task == FEATURE_EMBEDDING:
        self.model = CLIPModel.from_pretrained(self.model_id).to(
            self.map_location
        )
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        self.processor = AutoProcessor.from_pretrained(self.model_id)
      elif self.task == SUMMARIZATION and FLAN_T5 in self.model_id:
        self.pipeline = pipeline(
            task=self.task,
            model=self.model_id,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )
      else:
        self.pipeline = pipeline(
            task=self.task, model=self.model_id, device=self.device
        )

    self.initialized = True
    logging.info("Handler initialization done.")

  def _create_blip_model(self):
    """A helper for creating BLIP and BLIP2 models."""
    if SALESFORCE_BLIP2 in self.model_id:
      self.torch_type = torch.float16
      self.processor = Blip2Processor.from_pretrained(self.model_id)
      self.model = Blip2ForConditionalGeneration.from_pretrained(
          self.model_id, torch_dtype=self.torch_type
      ).to(self.map_location)
    else:
      self.torch_type = torch.float32
      self.processor = BlipProcessor.from_pretrained(self.model_id)
      if self.task == IMAGE_CAPTIONING:
        self.model = BlipForConditionalGeneration.from_pretrained(
            self.model_id
        ).to(self.map_location)
      elif self.task == VQA:
        self.model = BlipForQuestionAnswering.from_pretrained(self.model_id).to(
            self.map_location
        )

  def _reformat_detection_result(self, data: List[Any]) -> List[Any]:
    """Reformat zero-shot-object-detection output."""
    if not data:
      return [data]
    boxes = {}
    boxes["label"] = data[0]["label"]
    boxes["boxes"] = []
    for item in data:
      box = {}
      box["score"] = item["score"]
      box.update(item["box"])
      boxes["boxes"].append(box)
    outputs = [boxes]
    return outputs
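
  # Illustrative example of the reformatting above (the input follows the
  # transformers zero-shot-object-detection pipeline output format):
  #   input:  [{"label": "cat", "score": 0.97,
  #             "box": {"xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4}}]
  #   output: [{"label": "cat",
  #             "boxes": [{"score": 0.97,
  #                        "xmin": 1, "ymin": 2, "xmax": 3, "ymax": 4}]}]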

  def preprocess(
      self, data: Any
  ) -> Tuple[Optional[List[str]], Optional[List[Image.Image]]]:
    """Preprocess input data."""
    texts = None
    images = None
    if "text" in data[0]:
      texts = [item["text"] for item in data]
    if "image" in data[0]:
      images = [
          image_format_converter.base64_to_image(item["image"])
          for item in data
      ]
    return texts, images

  def inference(self, data: Any, *args, **kwargs) -> List[Any]:
    """Run the inference."""
    texts, images = data
    preds = None
    if self.task == ZERO_CLASSIFICATION:
      preds = self.pipeline(images=images, candidate_labels=texts)
    elif self.task == ZERO_DETECTION:
      # The object detection pipeline doesn't support batch prediction.
      preds = self.pipeline(image=images[0], candidate_labels=texts[0])
    elif self.task == IMAGE_CAPTIONING:
      if self.salesforce_blip:
        inputs = self.processor(images[0], return_tensors="pt").to(
            self.map_location, self.torch_type
        )
        preds = self.model.generate(**inputs)
        preds = [
            self.processor.decode(preds[0], skip_special_tokens=True).strip()
        ]
      else:
        preds = self.pipeline(images=images)
    elif self.task == VQA:
      # The VQA pipeline doesn't support batch prediction.
      if self.salesforce_blip:
        inputs = self.processor(images[0], texts[0], return_tensors="pt").to(
            self.map_location, self.torch_type
        )
        preds = self.model.generate(**inputs)
        preds = [
            self.processor.decode(preds[0], skip_special_tokens=True).strip()
        ]
      else:
        preds = self.pipeline(image=images[0], question=texts[0])
    elif self.task == DQA:
      # The DQA pipeline doesn't support batch prediction.
      preds = self.pipeline(image=images[0], question=texts[0])
    elif self.task == FEATURE_EMBEDDING:
      preds = {}
      if texts:
        inputs = self.tokenizer(
            text=texts, padding=True, return_tensors="pt"
        ).to(self.map_location)
        text_features = self.model.get_text_features(**inputs)
        preds["text_features"] = text_features.detach().cpu().numpy().tolist()
      if images:
        inputs = self.processor(images=images, return_tensors="pt").to(
            self.map_location
        )
        image_features = self.model.get_image_features(**inputs)
        preds["image_features"] = (
            image_features.detach().cpu().numpy().tolist()
        )
      preds = [preds]
    elif self.task == SUMMARIZATION and FLAN_T5 in self.model_id:
      texts = [SUMMARIZATION_TEMPLATE.format(input=text) for text in texts]
      preds = self.pipeline(texts, max_length=130)
    elif self.task == SUMMARIZATION and self.model_id == BART_LARGE_CNN:
      preds = self.pipeline(
          texts[0], max_length=130, min_length=30, do_sample=False
      )
    else:
      raise ValueError(f"Invalid TASK: {self.task}")
    return preds

  def postprocess(self, data: Any) -> List[Any]:
    """Postprocess the predictions."""
    if self.task == ZERO_DETECTION:
      data = self._reformat_detection_result(data)
    return data
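
# For illustration only (not part of the original handler): preprocess() above
# consumes a list of instances whose optional "image" values are
# base64-encoded images, e.g. for zero-shot image classification:
# [
#     {"text": "a photo of a cat", "image": "<base64-encoded image bytes>"},
#     {"text": "a photo of a dog", "image": "<base64-encoded image bytes>"},
# ]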