# vision/smolvlm2/smolvlm/model/processing_smollmm.py
from itertools import accumulate
from typing import TYPE_CHECKING, List, Optional, Union
from transformers.utils import logging
from transformers.feature_extraction_utils import BatchFeature
from transformers.image_utils import ImageInput, load_image
from transformers.processing_utils import Unpack
from transformers.tokenization_utils_base import TextInput, BatchEncoding
# Import the *parent* class
from transformers.models.idefics3.processing_idefics3 import (
Idefics3Processor,
Idefics3ProcessorKwargs,
is_url,
is_image_or_image_url,
get_image_prompt_string,
)
if TYPE_CHECKING:
    from transformers.tokenization_utils_base import PreTokenizedInput

logger = logging.get_logger(__name__)


class SmolLMMProcessor(Idefics3Processor):
"""
A subclass of Idefics3Processor that adds an `allow_mismatch` argument
to skip the 1:1 match check between #<image> tokens and #images.
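
    Illustrative usage (a sketch; the checkpoint id and the `frames` list of
    PIL images are assumptions, not part of this file):

        processor = SmolLMMProcessor.from_pretrained("HuggingFaceTB/SmolVLM2-2.2B-Instruct")
        # Eight frames for one <image> token raises by default; with
        # allow_mismatch=True the frames are regrouped as consecutive clips.
        batch = processor(images=frames, text="<image> Describe the clip.", allow_mismatch=True)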
"""
def __call__(
self,
images: Union[ImageInput, List[ImageInput], List[List[ImageInput]]] = None,
text: Union[TextInput, "PreTokenizedInput", List[TextInput], List["PreTokenizedInput"]] = None,
audio=None,
videos=None,
image_seq_len: Optional[int] = None,
allow_mismatch: bool = False, # <--- NEW ARG
**kwargs: Unpack[Idefics3ProcessorKwargs],
) -> BatchEncoding:
"""
        Process inputs as in `Idefics3Processor.__call__`. If `allow_mismatch=True`,
        skip the error normally raised when the number of `<image>` tokens does not
        equal the number of images (useful when passing several video frames per token).
        See the `Idefics3Processor.__call__` docstring for the remaining parameters.
"""
if text is None and images is None:
raise ValueError("You must provide either `text` or `images`.")
# Merge default keyword args for text/images
output_kwargs = self._merge_kwargs(
Idefics3ProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
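        # `output_kwargs` maps modality names to kwarg dicts ("text_kwargs",
        # "images_kwargs", ...) as built by `ProcessorMixin._merge_kwargs`;
        # only the text and image sub-dicts are consumed below.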
# If no override for image_seq_len is passed, use the default self.image_seq_len
image_seq_len = image_seq_len if image_seq_len is not None else self.image_seq_len
# Count how many <image> tokens are in each text sample
n_images_in_text = []
if text is not None:
if isinstance(text, str):
text = [text]
n_images_in_text = [sample.count(self.image_token.content) for sample in text]
inputs = BatchFeature()
# ---------------------------------------------------------------
# If images are provided, do all the logic that normally raises a mismatch error.
# We'll skip or warn if allow_mismatch is True.
# ---------------------------------------------------------------
if images is not None:
# Flatten or interpret images
if is_image_or_image_url(images):
images = [[images]]
elif isinstance(images, list) and is_image_or_image_url(images[0]):
# Original code raises error if mismatch. We'll skip if allow_mismatch.
if not allow_mismatch and text is not None and sum(n_images_in_text) != len(images):
raise ValueError(
f"The total number of <image> tokens in the prompts should match "
f"the number of images. Found {sum(n_images_in_text)} <image> tokens "
f"but {len(images)} images."
)
                elif text is not None and sum(n_images_in_text) != len(images):
                    logger.warning(
                        "Found %d <image> tokens but %d images; continuing because allow_mismatch=True.",
                        sum(n_images_in_text),
                        len(images),
                    )
# Re-group images to match text samples
if text is not None:
# Calculate frames per token
total_images = len(images)
total_tokens = sum(n_images_in_text)
                    # Guard total_tokens > 0: with allow_mismatch=True, a batch with
                    # no <image> tokens would otherwise divide by zero below.
                    if total_tokens > 0 and total_images > total_tokens and total_images % total_tokens == 0:
                        # More images than <image> tokens, evenly divisible: treat
                        # the extra images as consecutive video frames per token.
                        frames_per_token = total_images // total_tokens
                        # Regroup via cumulative token counts so that samples with
                        # differing numbers of <image> tokens slice the flat list
                        # correctly (the fixed-stride indexing assumed equal counts
                        # per sample), while keeping each sample's frames contiguous.
                        cum_tokens = [0] + list(accumulate(n_images_in_text))
                        images = [
                            images[cum_tokens[i] * frames_per_token : cum_tokens[i + 1] * frames_per_token]
                            for i in range(len(n_images_in_text))
                        ]
else:
# Original regrouping logic for other cases
cumsum_images_in_text = [0] + list(accumulate(n_images_in_text))
images = [
images[cumsum_images_in_text[i] : cumsum_images_in_text[i + 1]]
for i in range(len(n_images_in_text))
]
else:
images = [images]
            # Otherwise `images` must be a nested list of images; the original
            # and-of-negations could never fire for list inputs, so validate the
            # nested structure directly.
            elif not (
                isinstance(images, (list, tuple))
                and isinstance(images[0], (list, tuple))
                and is_image_or_image_url(images[0][0])
            ):
                raise ValueError(
                    "Invalid input images. Provide an image, a list of images, or a list of lists of images."
                )
n_images_in_images = [len(sample) for sample in images]
# Actually load images if they are URLs
images = [[load_image(im) if is_url(im) else im for im in sample] for sample in images]
# Let the parent's image_processor handle shape, resizing, etc.
image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"])
inputs.update(image_inputs)
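            # `image_inputs` from the Idefics3 image processor typically carries
            # `pixel_values`, `pixel_attention_mask`, and per-image "rows"/"cols"
            # split counts; the rows/cols entries are popped and consumed below.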
# If we have text, handle expansions
if text is not None:
                if not allow_mismatch and n_images_in_images != n_images_in_text:
                    raise ValueError(
                        f"Mismatch between images and <image> tokens. Found {n_images_in_text} "
                        f"<image> tokens per sample but {n_images_in_images} images per sample."
                    )
                elif n_images_in_images != n_images_in_text:
                    logger.warning(
                        "Per-sample image counts %s do not match <image> token counts %s; "
                        "continuing because allow_mismatch=True.",
                        n_images_in_images,
                        n_images_in_text,
                    )
# Rows/cols for expanded patch tokens
image_rows = inputs.pop("rows", [[0] * len(text)])
image_cols = inputs.pop("cols", [[0] * len(text)])
fake_image_token = self.fake_image_token.content
image_token = self.image_token.content
global_img_token = self.global_image_tag
prompt_strings = []
for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols):
image_prompt_strings = []
for n_rows, n_cols in zip(sample_rows, sample_cols):
image_prompt_string = get_image_prompt_string(
n_rows,
n_cols,
image_seq_len,
image_token=image_token,
fake_token_around_image=fake_image_token,
global_img_token=global_img_token,
)
image_prompt_strings.append(image_prompt_string)
split_sample = sample.split(image_token)
                # `str.split` always returns at least one element, so the former
                # `len(split_sample) == 0` check could never fire; a single element
                # means no <image> token was found in this sample.
                if len(split_sample) == 1 and image_prompt_strings:
                    raise ValueError("Expected <image> token in text, found none.")
                # Insert the expansion for each <image> placeholder: placeholder i
                # pairs with image_prompt_strings[i] (indexing with `i - 1` rotated
                # the expansions, giving the first placeholder the last expansion).
                combined = split_sample[0]
                for i, split_subsample in enumerate(split_sample[1:]):
                    combined += image_prompt_strings[i] + split_subsample
prompt_strings.append(combined)
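                # For reference: with n_rows == n_cols == 0 (no tiling), the
                # Idefics3 helper expands each placeholder to
                # "<fake_image_token><global_img_token>" followed by image_seq_len
                # copies of the image token and a closing fake_image_token; see
                # get_image_prompt_string in processing_idefics3.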
# Now tokenize the text with expansions
text_inputs = self.tokenizer(text=prompt_strings, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)
# -------------------------------------------------------------------
# If we have text only (no images)
# -------------------------------------------------------------------
elif text is not None:
# no images => zero <image> tokens
if any(n_images_in_text):
raise ValueError(
f"Found {sum(n_images_in_text)} <image> tokens in text, but no images were passed."
)
text_inputs = self.tokenizer(text=text, **output_kwargs["text_kwargs"])
inputs.update(text_inputs)
return inputs
    # batch_decode, decode, and model_input_names are inherited unchanged from
    # Idefics3Processor, so they are not overridden here.
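
# ---------------------------------------------------------------------------
# Minimal end-to-end sketch (assumptions: the checkpoint id is illustrative
# and `Image.new` stands in for real video frames; neither is in this file):
#
#     from PIL import Image
#
#     frames = [Image.new("RGB", (384, 384)) for _ in range(8)]
#     processor = SmolLMMProcessor.from_pretrained(
#         "HuggingFaceTB/SmolVLM2-2.2B-Instruct"
#     )
#     # One <image> token vs. eight images: raises without allow_mismatch,
#     # regroups the images as consecutive frames with it.
#     batch = processor(
#         images=frames,
#         text="<image> What happens in this clip?",
#         allow_mismatch=True,
#         return_tensors="pt",
#     )
# ---------------------------------------------------------------------------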