megatron_patch/data/image_processing.py (67 lines of code) (raw):

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. Except portions as noted which are Copyright (c) 2023 OpenGVLab and licensed under the MIT license found in LICENSE. import numpy as np import torch import math import random from PIL import Image, ImageDraw from torchvision import transforms as T from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage # Imagenet's mean and std. pixel_mean = [123.675, 116.28, 103.53] pixel_std = [58.395, 57.12, 57.375] # Reshape for broadcasting. pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1) pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1) def convert_to_rgb(image): return image.convert("RGB") def _transform_train_aug(): return Compose([ ToPILImage(), Resize(scale=random.random() / 2 + 0.5), convert_to_rgb, RandAugment(2, 5, isPIL=True, augs=['Identity', 'AutoContrast', 'Brightness', 'Sharpness', 'Equalize', 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate']), ]) def _transform_test(): return Compose([ ToPILImage(), convert_to_rgb, ]) def standardize_image(img): """Standardize image pixel values.""" return (torch.Tensor(np.array(img)).permute(2, 0, 1) - pixel_mean) / pixel_std def get_visual_transform( img, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280, augment=False ): img = np.array(img) if augment: visual_transform = _transform_train_aug() else: visual_transform = _transform_test() img = visual_transform(img) w, h = img.size h_bar, w_bar = smart_resize(h, w, factor, min_pixels, max_pixels) img = img.resize((w_bar, h_bar)) # Standardize pixel values. img = standardize_image(img) imgs = [img] return imgs # copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/image_processing_qwen2_vl.py def smart_resize( height: int, width: int, factor: int = 28, min_pixels: int = 56 * 56, max_pixels: int = 14 * 14 * 4 * 1280 ): """Rescales the image so that the following conditions are met: 1. Both dimensions (height and width) are divisible by 'factor'. 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. 3. The aspect ratio of the image is maintained as closely as possible. """ if height < factor or width < factor: raise ValueError(f"height:{height} or width:{width} must be larger than factor:{factor}") elif max(height, width) / min(height, width) > 200: raise ValueError( f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" ) h_bar = round(height / factor) * factor w_bar = round(width / factor) * factor if h_bar * w_bar > max_pixels: beta = math.sqrt((height * width) / max_pixels) h_bar = math.floor(height / beta / factor) * factor w_bar = math.floor(width / beta / factor) * factor elif h_bar * w_bar < min_pixels: beta = math.sqrt(min_pixels / (height * width)) h_bar = math.ceil(height * beta / factor) * factor w_bar = math.ceil(width * beta / factor) * factor return h_bar, w_bar