vision/smolvlm2/smolvlm/mm_utils.py (206 lines of code) (raw):

import os
import re
import math
import random
import base64
import logging
from io import BytesIO
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import ujson as json
import yaml
import transformers
import torchvision
from PIL import Image, ImageFile
from torch.utils.data import Dataset
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode

# Let PIL load truncated image files instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Raise PIL's decompression-bomb pixel limit so very large images can be opened.
Image.MAX_IMAGE_PIXELS = 1000000000

# Timeout (seconds) applied to HTTP(S) image downloads in `_fetch_image`, so a
# stalled server cannot block the caller indefinitely.
IMAGE_URL_TIMEOUT = 30

logger = logging.getLogger(__name__)

###############################################################################
# Basic helper logic: rounding, resizing, frames, etc.
###############################################################################


def round_by_factor(number: float, factor: int) -> int:
    """Round 'number' to the nearest integer multiple of 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Ceil 'number' to the nearest integer multiple of 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Floor 'number' to the nearest integer multiple of 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int,
    min_pixels: int,
    max_pixels: int,
    max_ratio: float,
) -> Tuple[int, int]:
    """
    Rescale (height, width) so that:
      - aspect ratio <= max_ratio,
      - total area in [min_pixels, max_pixels],
      - each dimension is a multiple of factor (and at least factor).

    Args:
        height, width: Original dimensions; must be positive.
        factor: Both returned dimensions are multiples of this value.
        min_pixels, max_pixels: Target bounds on the returned area.
        max_ratio: Largest allowed long-side / short-side ratio.

    Returns:
        (new_height, new_width).

    Raises:
        ValueError: if a dimension is not positive or the aspect ratio
            exceeds ``max_ratio``.
    """
    if height <= 0 or width <= 0:
        raise ValueError(f"height and width must be positive, got {height}x{width}")
    ratio = max(height, width) / min(height, width)
    if ratio > max_ratio:
        raise ValueError(f"Aspect ratio {ratio:.2f} > {max_ratio}")

    h_ = max(factor, round_by_factor(height, factor))
    w_ = max(factor, round_by_factor(width, factor))
    area = h_ * w_
    if area > max_pixels:
        # Shrink both sides by the same scale so the area fits under max_pixels.
        scale = math.sqrt((height * width) / max_pixels)
        # Clamp to one factor: flooring a very thin side could otherwise yield 0,
        # which would make the subsequent resize fail.
        h_ = max(factor, floor_by_factor(height / scale, factor))
        w_ = max(factor, floor_by_factor(width / scale, factor))
    elif area < min_pixels:
        # Grow both sides by the same scale so the area reaches min_pixels.
        scale = math.sqrt(min_pixels / (height * width))
        h_ = ceil_by_factor(height * scale, factor)
        w_ = ceil_by_factor(width * scale, factor)
    return h_, w_


def _smart_nframes(
    config: Dict[str, Any],
    total_frames: int,
    video_fps: float,
    frame_factor: int,
    default_fps: float,
    fps_min_frames: int,
    fps_max_frames: int,
) -> int:
    """
    Decide how many frames to sample from a video.

    - If 'nframes' is in config, it is rounded to a multiple of 'frame_factor'.
    - Otherwise the count is derived from config['fps'] (or 'default_fps'),
      clamped to [min_frames, max_frames] (config overrides, falling back to
      'fps_min_frames' and min(fps_max_frames, total_frames)), then rounded to
      a multiple of 'frame_factor'.

    Raises:
        ValueError: if both 'nframes' and 'fps' are given, or if the result is
            outside [frame_factor, total_frames].
    """
    if "nframes" in config and "fps" in config:
        raise ValueError("Provide only one of `fps` or `nframes` in config.")
    if "nframes" in config:
        nframes = round_by_factor(config["nframes"], frame_factor)
    else:
        user_fps = config.get("fps", default_fps)
        minf = ceil_by_factor(config.get("min_frames", fps_min_frames), frame_factor)
        maxf = floor_by_factor(config.get("max_frames", min(fps_max_frames, total_frames)), frame_factor)
        # Number of frames the clip would yield when resampled at user_fps.
        val = total_frames / video_fps * user_fps
        val = min(max(val, minf), maxf)
        nframes = round_by_factor(val, frame_factor)
    if not (frame_factor <= nframes <= total_frames):
        raise ValueError(f"Invalid nframes={nframes}, out of range.")
    return int(nframes)


def _read_video_torchvision(
    config: Dict[str, Any],
    frame_factor: int,
    default_fps: float,
    fps_min_frames: int,
    fps_max_frames: int,
) -> torch.Tensor:
    """
    Read config['video'] with torchvision.io.read_video and return a TCHW tensor
    of frames sampled evenly (via linspace) across the decoded clip.

    Optional config keys 'video_start' / 'video_end' (seconds) bound the clip.
    """
    path = config["video"]
    vid, _, info = io.read_video(
        path,
        start_pts=config.get("video_start", 0.0),
        end_pts=config.get("video_end", None),
        pts_unit="sec",
        output_format="TCHW",
    )
    total_frames = vid.size(0)
    video_fps = info["video_fps"]
    nframes = _smart_nframes(config, total_frames, video_fps, frame_factor, default_fps, fps_min_frames, fps_max_frames)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long()
    return vid[idx]


def _read_video_decord(
    config: Dict[str, Any],
    frame_factor: int,
    default_fps: float,
    fps_min_frames: int,
    fps_max_frames: int,
) -> torch.Tensor:
    """
    Read config['video'] with decord and return a TCHW tensor of frames
    sampled evenly (via linspace) across the whole video.
    """
    import decord

    path = config["video"]
    vr = decord.VideoReader(path)
    total_frames = len(vr)
    video_fps = vr.get_avg_fps()
    nframes = _smart_nframes(config, total_frames, video_fps, frame_factor, default_fps, fps_min_frames, fps_max_frames)
    idx = torch.linspace(0, total_frames - 1, nframes).round().long().tolist()
    arr = vr.get_batch(idx).asnumpy()  # T,H,W,C
    return torch.from_numpy(arr).permute(0, 3, 1, 2)  # -> T,C,H,W


VIDEO_READERS = {
    "torchvision": _read_video_torchvision,
    "decord": _read_video_decord,
}


def pick_video_reader() -> str:
    """Return "decord" if the decord package is importable, else "torchvision"."""
    import importlib.util

    try:
        spec = importlib.util.find_spec("decord")
    except (ImportError, ValueError):
        # find_spec raises ValueError when "decord" sits in sys.modules with
        # __spec__ set to None; treat any lookup failure as "not available".
        spec = None
    return "decord" if spec is not None else "torchvision"


###############################################################################
# Multimodal fetch_image / fetch_video
###############################################################################


def _fetch_image(
    config: Dict[str, Any],
    image_factor: int,
    min_pixels: int,
    max_pixels: int,
    max_ratio: float,
) -> Image.Image:
    """
    Load a single image and resize it according to 'smart_resize' constraints.

    The source is config['image'] (or config['image_url']) and may be a PIL
    image, an http(s) URL, a 'file://' URI, a 'data:image...;base64,' URI, or a
    local path. The image is converted to RGB.

    If both 'resized_height' and 'resized_width' are in config they are the
    resize target (snapped by smart_resize with the given pixel bounds);
    otherwise the image's own size is used, with optional per-item
    'min_pixels' / 'max_pixels' overrides from config.

    Raises:
        ValueError: for a missing/unsupported source or malformed base64 data.
        requests.HTTPError: if an http(s) download returns an error status.
    """
    source = config.get("image") or config.get("image_url")
    if not source:
        raise ValueError("Must provide either 'image' or 'image_url' in config.")

    # Load
    if isinstance(source, Image.Image):
        pil_img = source
    elif isinstance(source, str):
        if source.startswith(("http://", "https://")):
            import requests

            # Use the decoded body (`.content`), not `.raw`: `.raw` is the
            # undecoded transport stream, so gzip/deflate-encoded responses
            # would not be recognised as images. The `with` closes the response.
            with requests.get(source, timeout=IMAGE_URL_TIMEOUT) as resp:
                resp.raise_for_status()
                pil_img = Image.open(BytesIO(resp.content))
        elif source.startswith("file://"):
            pil_img = Image.open(source[len("file://"):])
        elif source.startswith("data:image"):
            # base64 data
            if "base64," in source:
                _, b64_data = source.split("base64,", 1)
                raw = base64.b64decode(b64_data)
                pil_img = Image.open(BytesIO(raw))
            else:
                raise ValueError("Invalid base64 image data.")
        else:
            # local path
            pil_img = Image.open(source)
    else:
        raise ValueError(f"Unsupported type for 'image': {type(source)}")

    pil_img = pil_img.convert("RGB")

    # Resize
    if "resized_height" in config and "resized_width" in config:
        rh, rw = smart_resize(
            config["resized_height"], config["resized_width"], image_factor, min_pixels, max_pixels, max_ratio
        )
    else:
        # infer dims from the image (PIL reports size as (width, height))
        w, h = pil_img.size
        local_min = config.get("min_pixels", min_pixels)
        local_max = config.get("max_pixels", max_pixels)
        rh, rw = smart_resize(h, w, image_factor, local_min, local_max, max_ratio)

    # Return the resized image (PIL expects (width, height))
    return pil_img.resize((rw, rh))


def _fetch_video(
    config: Dict[str, Any],
    image_factor: int,
    min_pixels: int,
    max_pixels: int,
    max_ratio: float,
    video_total_pixels: int,
    frame_factor: int,
    default_fps: float,
    fps_min_frames: int,
    fps_max_frames: int,
) -> Union[torch.Tensor, List[Image.Image]]:
    """
    Load a video described by config['video'].

    - str: read with the preferred backend (see pick_video_reader), sample
      frames, and resize all frames with bicubic interpolation; returns a
      float TCHW tensor.
    - list: treat each entry as an image source loaded via _fetch_image;
      returns a list of PIL images, padded by repeating the last frame up to a
      multiple of 'frame_factor' (an empty list is returned unchanged).

    Raises:
        ValueError: if config['video'] is neither a str nor a list.
    """
    val = config["video"]

    if isinstance(val, str):
        # Single video path
        backend = pick_video_reader()
        fn = VIDEO_READERS[backend]
        vid_tensor = fn(config, frame_factor, default_fps, fps_min_frames, fps_max_frames)
        # shape => T, C, oh, ow
        t, c, oh, ow = vid_tensor.shape

        local_min = config.get("min_pixels", min_pixels)
        local_max = config.get("max_pixels", max_pixels)
        local_total = config.get("total_pixels", video_total_pixels)
        # Per-frame pixel budget: spread the total budget over the sampled
        # frames, but never below slightly above the per-frame minimum.
        guess_max = max(min(local_max, local_total / t * frame_factor), int(local_min * 1.05))

        if "resized_height" in config and "resized_width" in config:
            rh, rw = smart_resize(
                config["resized_height"], config["resized_width"], image_factor, local_min, guess_max, max_ratio
            )
        else:
            rh, rw = smart_resize(oh, ow, image_factor, local_min, guess_max, max_ratio)

        # Resize frames
        vid_tensor = transforms.functional.resize(
            vid_tensor,
            [rh, rw],
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ).float()
        return vid_tensor

    elif isinstance(val, list):
        # List of frame paths
        frames = []
        meta = dict(config)
        meta.pop("video", None)
        for fp in val:
            meta["image"] = fp
            frames.append(_fetch_image(meta, image_factor, min_pixels, max_pixels, max_ratio))

        # Possibly pad frames to multiple of frame_factor
        needed = ceil_by_factor(len(frames), frame_factor)
        if len(frames) < needed and len(frames) > 0:
            frames += [frames[-1]] * (needed - len(frames))
        return frames

    else:
        raise ValueError(f"'video' must be a str or list, got {type(val)}")


def tokenizer_image_token(prompt, tokenizer, return_tensors=None):
    """
    Tokenize 'prompt' with 'tokenizer' and return the first row of input_ids.

    NOTE(review): despite the name, no image-token-specific handling happens
    here; it is a plain tokenizer call.
    """
    return tokenizer(prompt, return_tensors=return_tensors).input_ids[0]