maga_transformer/utils/multimodal_util.py (94 lines of code) (raw):

import os
import torch
import json
import requests
import threading
from enum import IntEnum
from io import BytesIO
from typing import Any, Callable, Optional
from PIL import Image
from dataclasses import dataclass, field
from maga_transformer.utils.lru_dict import LruDict
from maga_transformer.utils.oss_util import get_bytes_io_from_oss_path

# HTTP headers sent when downloading multimodal inputs over HTTP(S).
# Setting DOWNLOAD_HEADERS to a JSON object replaces the defaults entirely.
if os.environ.get('DOWNLOAD_HEADERS', '') != '':
    HTTP_HEADS = json.loads(os.environ['DOWNLOAD_HEADERS'])
else:
    HTTP_HEADS = {
        'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0',
        'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7',
    }


class MMUrlType(IntEnum):
    """Kind of multimodal payload referenced by a MultimodalInput."""
    DEFAULT = 0
    IMAGE = 1
    VIDEO = 2
    AUDIO = 3
    TENSOR = 4


@dataclass
class MMPreprocessConfig:
    """Preprocessing knobs for a multimodal input.

    Every field defaults to -1; presumably -1 means "not set, use the
    model's default" -- NOTE(review): confirm against the consumers.
    """
    width: int = -1
    height: int = -1
    min_pixels: int = -1
    max_pixels: int = -1
    fps: int = -1
    min_frames: int = -1
    max_frames: int = -1


class MultimodalInput:
    """A single multimodal item: its URL/path, type, preprocessing config and tensor."""
    url: str
    mm_type: MMUrlType
    config: MMPreprocessConfig
    tensor: torch.Tensor

    def __init__(self, url: str, mm_type: MMUrlType = MMUrlType.DEFAULT,
                 config: Optional[MMPreprocessConfig] = None,
                 tensor: Optional[torch.Tensor] = None):
        """Create an input.

        Args:
            url: HTTP(S) URL, OSS path, or local file path of the payload.
            mm_type: payload kind.
            config: preprocessing config; a fresh default ``MMPreprocessConfig()``
                is created when omitted.
            tensor: optional precomputed tensor; a fresh ``torch.empty(1)``
                placeholder is created when omitted.
        """
        self.url = url
        self.mm_type = mm_type
        # Defaults are built per instance: an object evaluated once in the
        # signature would be shared (and mutable) across every instance.
        self.config = config if config is not None else MMPreprocessConfig()
        self.tensor = tensor if tensor is not None else torch.empty(1)


def get_vit_compute_dtype(dtype: str):
    """Return torch.bfloat16 for "bf16"; any other string maps to torch.half."""
    if dtype == "bf16":
        return torch.bfloat16
    else:
        return torch.half


def _copy_to_bytes_io(buf: Any) -> BytesIO:
    """Return a new, independent BytesIO holding the full contents of *buf*.

    For a BytesIO, getvalue() returns the whole buffer regardless of the
    current stream position and does not move it. Any other object is
    assumed to be a seekable binary file-like object (NOTE(review): this is
    what get_bytes_io_from_oss_path is presumed to return -- confirm).
    """
    if isinstance(buf, BytesIO):
        return BytesIO(buf.getvalue())
    buf.seek(0)
    return BytesIO(buf.read())


def _download_bytes_io(url: str) -> BytesIO:
    """Fetch *url* (HTTP(S), OSS, or local path) into a fresh BytesIO.

    Raises:
        Exception: on a non-200 HTTP status, or whatever the underlying
            fetch (requests / OSS helper / open) raises.
    """
    if url.startswith(("http://", "https://")):
        # The context manager guarantees the connection is released even on
        # the error path, where the streamed body is never read.
        with requests.get(url, stream=True, headers=HTTP_HEADS, timeout=10) as response:
            if response.status_code != 200:
                raise Exception(f'download failed, error code: {response.status_code}')
            return BytesIO(response.content)
    if url.startswith("oss"):
        return _copy_to_bytes_io(get_bytes_io_from_oss_path(url))
    # Anything else is treated as a local file path.
    with open(url, "rb") as fh:
        return BytesIO(fh.read())


def get_bytes_io_from_url(url: str) -> BytesIO:
    """Return the bytes behind *url* as a BytesIO positioned at 0.

    Results are cached in ``url_data_cache_``. Every call returns a new
    BytesIO, so concurrent callers never share a stream position and the
    cached copy is never handed out directly.

    Raises:
        Exception: wrapping (and chained to) any download/read failure.
    """
    cached_res = url_data_cache_.check_cache(url)
    if cached_res is not None:
        return _copy_to_bytes_io(cached_res)
    try:
        res = _download_bytes_io(url)
    except Exception as e:
        raise Exception(f"download and load {url} error, exception {e}") from e
    url_data_cache_.insert_cache(url, res)
    return _copy_to_bytes_io(res)


class MMDataCache(object):
    """Lock-guarded cache keyed by URL, backed by an LruDict.

    A non-positive ``cache_size`` disables caching: lookups return None and
    inserts are ignored.
    """

    def __init__(self, cache_size: int = 10):
        self.mm_data_cache: Optional[LruDict] = None
        self.cache_lock = threading.Lock()
        if cache_size > 0:
            self.mm_data_cache = LruDict(cache_size)

    def check_cache(self, url: str) -> Any:
        """Return the cached value for *url*, or None on a miss or when disabled."""
        if self.mm_data_cache is None:
            return None
        with self.cache_lock:
            # Membership test + __getitem__ (rather than .get) is kept in case
            # LruDict refreshes recency on item access (not visible here).
            if url in self.mm_data_cache:
                return self.mm_data_cache[url]
            else:
                return None

    def insert_cache(self, url: str, features: Any) -> None:
        """Store *features* under *url*; a no-op when caching is disabled."""
        if self.mm_data_cache is None:
            return
        with self.cache_lock:
            self.mm_data_cache[url] = features


# Presumably caches ViT embeddings (judging by the name); size comes from
# MM_CACHE_ITEM_NUM, and the default of 0 disables it.
vit_emb_cache_ = MMDataCache(int(os.environ.get('MM_CACHE_ITEM_NUM', '0')))
# Caches downloaded raw bytes per URL (LruDict capacity 100).
url_data_cache_ = MMDataCache(100)