server/text_generation_server/models/custom_modeling/gemma3/processing_gemma3.py (137 lines of code) (raw):

# coding=utf-8
# Copyright 2025 Google Inc. HuggingFace Inc. team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import List, Optional, Union

from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput
from transformers.image_utils import PILImageResampling
from transformers.processing_utils import (
    ImagesKwargs,
    ProcessingKwargs,
    ProcessorMixin,
    Unpack,
)
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.utils import to_py_obj

from text_generation_server.models.custom_modeling.gemma3.image_processing_gemma3 import (
    Gemma3ImageProcessor,
)

from .utils import make_nested_list_of_images


# Gemma3-specific image-processor options, on top of the base ImagesKwargs.
class Gemma3ImagesKwargs(ImagesKwargs):
    do_pan_and_scan: Optional[bool]
    pan_and_scan_min_crop_size: Optional[int]
    pan_and_scan_max_num_crops: Optional[int]
    pan_and_scan_min_ratio_to_activate: Optional[float]
    do_convert_rgb: Optional[bool]


# Processor kwargs with Gemma3 defaults: no text padding, and Pan-and-Scan
# disabled by default (min crop size 256, at most 4 crops, activation ratio 1.2).
class Gemma3ProcessorKwargs(ProcessingKwargs, total=False):
    _defaults = {
        "text_kwargs": {
            "padding": False,
        },
        "images_kwargs": {
            "do_pan_and_scan": False,
            "pan_and_scan_min_crop_size": 256,
            "pan_and_scan_max_num_crops": 4,
            "pan_and_scan_min_ratio_to_activate": 1.2,
        },
    }


class Gemma3Processor(ProcessorMixin):
    """Pair a Gemma tokenizer with a Gemma3 image processor.

    ``__call__`` expands ``"<image>"`` placeholders in prompts into the full
    image-token sequence and returns the tokenized text together with the
    image processor outputs in a single ``BatchFeature``.
    """

    attributes = ["image_processor", "tokenizer"]
    valid_kwargs = ["chat_template"]
    # Upstream uses "Gemma3ImageProcessor"; in this port the concrete image
    # processor is always constructed inside __init__.
    image_processor_class = "AutoProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor,
        tokenizer,
        chat_template=None,
        num_mm_soft_tokens_per_image: int = 256,
        **kwargs,
    ):
        """Build the processor around ``tokenizer``.

        Args:
            image_processor: Ignored; a ``Gemma3ImageProcessor`` is always
                constructed here (896x896, mean/std 127.5, bilinear
                resampling, no rescaling).
            tokenizer: Must expose ``image_token``, ``image_token_id``,
                ``boi_token`` and ``eoi_token``.
            chat_template: Ignored; the processor is always created without
                a chat template.
            num_mm_soft_tokens_per_image: Ignored; fixed at 256.
            **kwargs: Ignored.
        """
        # These overrides are deliberate: this port pins the Gemma3 image
        # settings regardless of what the caller passes in.
        num_mm_soft_tokens_per_image = 256
        chat_template = None

        image_processor = Gemma3ImageProcessor(
            image_mean=(127.5,) * 3,
            image_std=(127.5,) * 3,
            size={"height": 896, "width": 896},
            do_rescale=False,
            resample=PILImageResampling.BILINEAR,
        )

        self.image_token_id = tokenizer.image_token_id
        image_tokens_expanded = "".join(
            [tokenizer.image_token] * num_mm_soft_tokens_per_image
        )
        # What every "<image>" placeholder in a prompt is finally replaced with.
        self.full_image_sequence = (
            f"\n\n{tokenizer.boi_token}{image_tokens_expanded}{tokenizer.eoi_token}\n\n"
        )

        self.image_processor = image_processor
        self.tokenizer = tokenizer
        self.chat_template = chat_template
        # NOTE(review): ProcessorMixin.__init__ is intentionally not called;
        # the attributes are assigned directly above. This presumably skips
        # its validation of the hand-built image processor. Confirm before
        # re-enabling the super() call.

    def __call__(
        self,
        images: ImageInput = None,
        text: Union[
            TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]
        ] = None,
        videos=None,
        audio=None,
        **kwargs: Unpack[Gemma3ProcessorKwargs],
    ) -> BatchFeature:
        """Tokenize ``text`` and preprocess ``images`` into one ``BatchFeature``.

        Every ``"<image>"`` placeholder is replaced by ``self.full_image_sequence``.
        When the image processor reports Pan-and-Scan crops for an image, that
        image's placeholder is first rewritten into a sentence containing one
        placeholder for the original image plus one per crop.

        Args:
            images: Images for the batch. They are passed through
                ``make_nested_list_of_images``, which must yield one list of
                images per prompt.
            text: A prompt string or a list of prompts. When omitted (and
                images are given), one prompt of space-separated placeholders
                is generated per sample.
            videos: Accepted for interface compatibility; unused.
            audio: Accepted for interface compatibility; unused.
            **kwargs: Merged with the ``Gemma3ProcessorKwargs`` defaults and
                routed to the tokenizer and image processor.

        Returns:
            ``BatchFeature`` containing the tokenizer outputs plus the image
            processor outputs (with ``num_crops`` removed).

        Raises:
            ValueError: If neither ``text`` nor ``images`` is given, if ``text``
                fails the type check, if the image and text batch sizes differ,
                or if a prompt's placeholder count differs from its image count.
        """
        if text is None and images is None:
            raise ValueError("Provide at least one of `text` or `images`.")

        output_kwargs = self._merge_kwargs(
            Gemma3ProcessorKwargs,
            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
            **kwargs,
        )

        if isinstance(text, str):
            text = [text]
        elif text is None:
            # Images-only call. Previously this fell into the check below and
            # crashed on ``None[0]``; placeholder prompts are now built below.
            text = []
        elif not isinstance(text, list) and not isinstance(text[0], str):
            # NOTE: only rejects non-list inputs whose first item is not a str.
            raise ValueError(
                "Invalid input text. Please provide a string, or a list of strings"
            )

        image_inputs = {}
        if images is not None:
            batched_images = make_nested_list_of_images(images)
            image_inputs = self.image_processor(
                batched_images, **output_kwargs["images_kwargs"]
            )

            # Create empty text to be replaced with placeholders
            if not text:
                text = [
                    " ".join(["<image>"] * len(sample_images))
                    for sample_images in batched_images
                ]

            if len(batched_images) != len(text):
                raise ValueError(
                    f"Received inconsistently sized batches of images ({len(batched_images)}) and text ({len(text)})."
                )

            # Replace image tokens by the full expanded sequence.
            # Rewritten prompts go into a copy so the caller's list is not mutated.
            batch_num_crops = to_py_obj(image_inputs.pop("num_crops"))
            expanded_text = list(text)
            for batch_idx, (prompt, sample_images, num_crops) in enumerate(
                zip(text, batched_images, batch_num_crops)
            ):
                image_indexes = [m.start() for m in re.finditer("<image>", prompt)]
                if len(sample_images) != len(image_indexes):
                    raise ValueError(
                        f"Prompt contained {len(image_indexes)} image tokens but received {len(sample_images)} images."
                    )

                # Insert additional image tokens for Pan-and-Scan crops.
                # Splice right-to-left so earlier placeholder offsets stay valid.
                for num, idx in reversed(list(zip(num_crops, image_indexes))):
                    if num:
                        formatted_image_text = (
                            "Here is the original image <image> and here are some crops to help you see better "
                            + " ".join(["<image>"] * num)
                        )
                        prompt = (
                            prompt[:idx]
                            + formatted_image_text
                            + prompt[idx + len("<image>") :]
                        )
                # Persist the rewritten prompt; rebinding the loop variable
                # alone would silently drop the crop placeholders.
                expanded_text[batch_idx] = prompt
            text = expanded_text

        # Expand placeholder image tokens to the full image token sequence
        text = [prompt.replace("<image>", self.full_image_sequence) for prompt in text]

        text_input = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
        return BatchFeature(data={**text_input, **image_inputs})

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Gemma
    def batch_decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, **kwargs)

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Gemma
    def decode(self, *args, **kwargs):
        """
        This method forwards all its arguments to GemmaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, **kwargs)

    @property
    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names with CLIP->PaliGemma
    def model_input_names(self):
        """Return the tokenizer's then the image processor's input names, de-duplicated in order."""
        tokenizer_input_names = self.tokenizer.model_input_names
        image_processor_input_names = self.image_processor.model_input_names
        return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names))


__all__ = ["Gemma3Processor"]