# vision/smolvlm2/smolvlm/datasets/dataset.py
import os
import math
import random
import time
import copy
import json
import logging
from typing import Any, Dict, List, Optional, Tuple
import decord
from decord import VideoReader, cpu
decord.bridge.set_bridge("torch")
from num2words import num2words
import datetime
import re
import torch
import numpy as np
import transformers
from torch.utils.data import Dataset
from PIL import Image, ImageFile
from smolvlm.constants import (
IGNORE_INDEX,
DATA_IMAGE_TOKEN,
DEFAULT_IMAGE_TOKEN,
DATA_VIDEO_TOKEN,
DEFAULT_VIDEO_TOKEN,
)
from smolvlm.train.args import DataArguments, TrainingArguments, ModelArguments
from smolvlm.utils import mprint
logger = logging.getLogger(__name__)
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = 1000000000
##############################################################################
# helper functions
##############################################################################
# Video Loader
##############################################################################
DEFAULT_SYSTEM_MESSAGE = "You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
# DEFAULT_VIDEO_INTRO = "Here are some frames sampled from a video:"
DEFAULT_VIDEO_INTRO = (
"You are provided the following series of {frame_count} frames "
"from a {video_duration} [H:MM:SS] video.\n"
)
DEFAULT_IMAGE_INTRO = "Here are some images: "
DEFAULT_MEDIA_OUTTRO = "Now answer the following question: "
FRAME_TIMESTAMP_MESSAGE = "Frame from"
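# For example (values illustrative):
#   DEFAULT_VIDEO_INTRO.format(frame_count=num2words(8), video_duration=str(datetime.timedelta(seconds=42)))
#   -> "You are provided the following series of eight frames from a 0:00:42 [H:MM:SS] video.\n"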
def load_video(
path: str,
max_frames: int = 100,
target_fps: float = 2.0,
skip_secs: float = 1.0
) -> Tuple[List[Image.Image], List[str], float]:
"""
Loads a video from `path` using decord, sampling up to `max_frames` frames.
After deduplicating indices (e.g., to handle rounding collisions), each frame
is decoded into a PIL Image (in RGB mode). Timestamps are generated in "MM:SS" format
based on the frame index over `native_fps`.
Args:
path (str): Path to the video file (e.g., MP4).
max_frames (int): Hard cap on how many frames we ever pick in total.
target_fps (float): Target approximate sampling rate in frames per second.
skip_secs (float): Number of seconds to skip at the beginning and end if
the video is sufficiently long ((duration - 2*skip_secs) > max_frames * target_fps).
    Returns:
        Tuple[List[Image.Image], List[str], float]:
            - A list of PIL.Image objects corresponding to each selected frame.
            - A parallel list of "MM:SS" timestamp strings, one per selected frame.
            - The video duration in seconds (total_frames / native_fps).
    """
try:
# Use decord with single-thread and CPU context
vr = VideoReader(path, num_threads=1, ctx=cpu(0))
except Exception as e:
raise RuntimeError(f"Failed to open video '{path}': {e}")
total_frames = len(vr)
if total_frames == 0:
raise RuntimeError(f"Video '{path}' has 0 frames.")
# Fallback to 30 if native_fps is None or zero
native_fps = vr.get_avg_fps() or 30.0
duration_seconds = total_frames / native_fps
# Estimate how many frames we'd get if we sample at `target_fps`.
estimated_frames = int(round(target_fps * duration_seconds)) if target_fps > 0 else max_frames
desired_frames = min(estimated_frames, max_frames)
if desired_frames < 1:
desired_frames = 1
start_idx = 0
end_idx = total_frames - 1
# Centered skip if we want fewer frames than max_frames
if desired_frames < max_frames:
leftover = total_frames - desired_frames
start_idx = leftover // 2
end_idx = total_frames - (leftover - start_idx)
# Otherwise, if video is long enough, skip a bit from start and end
elif skip_secs > 0 and (duration_seconds - 2 * skip_secs) > (max_frames * target_fps):
start_idx = int(skip_secs * native_fps)
end_idx = int(total_frames - skip_secs * native_fps)
# Ensure valid start / end
start_idx = max(start_idx, 0)
end_idx = min(end_idx, total_frames - 1)
if start_idx >= end_idx:
start_idx = 0
end_idx = total_frames - 1
# Uniformly sample the desired number of frames from [start_idx..end_idx]
frames_idx = np.linspace(start_idx, end_idx, desired_frames, dtype=int)
frames_idx = np.unique(frames_idx).tolist()
# Read frames from decord
try:
frames_tensor = vr.get_batch(frames_idx).cpu().numpy() # (N, H, W, C)
except Exception as e:
raise RuntimeError(f"Failed to read frames from '{path}': {e}")
# Convert to PIL Images
frames_out = [Image.fromarray(arr).convert("RGB") for arr in frames_tensor]
# Build timestamps (MM:SS) for each selected frame index
timestamps = []
for idx in frames_idx:
sec = idx / native_fps
mm = int(sec // 60)
ss = int(sec % 60)
timestamps.append(f"{mm:02d}:{ss:02d}")
return frames_out, timestamps, duration_seconds
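# Illustrative usage (the path is hypothetical):
#   frames, timestamps, duration = load_video("clips/demo.mp4", max_frames=50, target_fps=2.0)
#   len(frames) == len(timestamps) <= 50; timestamps look like ["00:00", "00:00", "00:01", ...]
#   and duration is the full video length in seconds.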
# Video Loader from sampled videos
##############################################################################
def load_image_directory_as_frames(
folder_path: str,
source_fps: float = 1.0,
target_fps: float = 1.0,
max_frames: int = 50,
) -> Tuple[List[Image.Image], List[str], float]:
"""
Treats a directory of images as if they were consecutive frames in a
pseudo-video recorded at `source_fps`, then samples frames to achieve
an approximate `target_fps`, subject to a limit of `max_frames`.
Args:
folder_path (str): Directory path containing image frames (like "frame_001.jpg").
source_fps (float): The framerate at which these images were presumably captured.
target_fps (float): The approximate sampling rate we want in the output.
max_frames (int): Hard limit on how many frames we return.
    Returns:
        (frames, timestamps, duration_seconds):
            frames: List of loaded PIL.Image (RGB).
            timestamps: Parallel list of "MM:SS" strings indicating each frame's approximate time.
            duration_seconds: Duration of the pseudo-video in seconds (total_frames / source_fps).
Raises:
RuntimeError: If `folder_path` doesn't exist or has no valid images,
or if we fail to load any frames after sampling.
"""
if not os.path.isdir(folder_path):
raise RuntimeError(f"Path '{folder_path}' is not a directory.")
# 1) Gather potential image files
image_extensions = (".jpg", ".jpeg", ".png")
files = [f for f in os.listdir(folder_path) if f.lower().endswith(image_extensions)]
if not files:
raise RuntimeError(f"No image files found in directory '{folder_path}'.")
# 2) Extract numeric index from filenames, sort by (base, index)
pattern = re.compile(r"(.*?)[-_]?(\d+)\..*$")
numbered_files = []
for fname in files:
match = pattern.match(fname)
if match:
base, num_str = match.groups()
try:
num = int(num_str)
numbered_files.append((fname, base, num))
except ValueError:
pass # skip weird filenames
if not numbered_files:
raise RuntimeError(f"No valid numbered filenames found in '{folder_path}'.")
# Sort primarily by base name, then by the numeric portion
numbered_files.sort(key=lambda x: (x[1], x[2]))
sorted_files = [nf[0] for nf in numbered_files]
total_frames = len(sorted_files)
# If no frames => we raise an error
if total_frames == 0:
raise RuntimeError(f"Directory '{folder_path}' appears empty after sorting valid images.")
# 3) Compute the pseudo-video’s duration => total_frames / source_fps
# Then how many frames we want for target_fps => target_frames
duration_seconds = total_frames / float(source_fps or 1.0) # avoid dividing by 0
estimated_frames = int(round(target_fps * duration_seconds)) if target_fps > 0 else max_frames
desired_frames = min(estimated_frames, max_frames)
# 4) Generate the final list of indices
frame_indices = np.linspace(0, total_frames - 1, desired_frames, dtype=int).tolist()
    # If sampling produced no indices (e.g. desired_frames == 0), fall back to a single frame
if not frame_indices:
frame_indices = [0] # at least one
# 5) Load frames
frames = []
timestamps = []
for idx in frame_indices:
img_path = os.path.join(folder_path, sorted_files[idx])
        sec = idx / float(source_fps or 1.0)
mm = int(sec // 60)
ss = int(sec % 60)
try:
img = Image.open(img_path).convert("RGB")
frames.append(img)
timestamps.append(f"{mm:02d}:{ss:02d}")
except Exception as e:
logger.error(f"Failed to load image '{img_path}': {e}")
# We skip the broken image
continue
# If we ended up with zero loaded => raise
if not frames:
raise RuntimeError(f"No frames successfully loaded from '{folder_path}' after sampling.")
return frames, timestamps, duration_seconds
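# Illustrative usage (directory name and frame files are hypothetical, e.g. frame_000.jpg, frame_001.jpg, ...):
#   frames, timestamps, duration = load_image_directory_as_frames(
#       "clips/demo_frames", source_fps=1.0, target_fps=1.0, max_frames=50)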
# Image Loader
##############################################################################
def load_single_image(img_path: str) -> Image.Image:
jpeg = Image.open(img_path)
img = jpeg.copy().convert('RGB')
return img
##############################################################################
# Helper Functions for Masking
##############################################################################
def find_global_img_patterns(tokens: List[str]) -> List[int]:
mask_positions = []
for i in range(len(tokens) - 4):
if (
tokens[i] == '<'
and tokens[i+1] == 'global'
and tokens[i+2] == '-'
and tokens[i+3] == 'img'
and tokens[i+4] == '>'
):
mask_positions.extend([i, i+1, i+2, i+3, i+4])
return mask_positions
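# Example: tokens ["<", "global", "-", "img", ">"] -> [0, 1, 2, 3, 4], i.e. a "<global-img>"
# marker that the tokenizer split into five pieces instead of one special token.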
def find_row_col_patterns(tokens: List[str]) -> List[int]:
pattern = re.compile(r'^< row _ [1-9] _ col _ [1-9] >$')
mask_positions = []
for i in range(len(tokens) - 8):
# Slice out exactly 9 tokens (e.g. <, row, _, 1, _, col, _, 1, >)
group = tokens[i : i + 9]
if pattern.fullmatch(" ".join(group)):
mask_positions.extend(range(i, i + 9))
return mask_positions
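# Example: tokens ["<", "row", "_", "1", "_", "col", "_", "2", ">"] -> [0, 1, ..., 8],
# i.e. a "<row_1_col_2>" tile marker that was split into nine pieces.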
def _search_subsequence(
sequence: torch.Tensor,
pattern: List[int],
start: int = 0
) -> int:
"""
Searches for the first occurrence of 'pattern' in 'sequence'
starting at offset 'start'. Returns the index of that occurrence,
or -1 if not found.
"""
# Convert input_ids to a Python list
seq_list = sequence.tolist()
pat_len = len(pattern)
if pat_len == 0:
return -1
# Simple forward search
for i in range(start, len(seq_list) - pat_len + 1):
if seq_list[i : i + pat_len] == pattern:
return i
return -1
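# Example: _search_subsequence(torch.tensor([7, 1, 2, 3, 2, 3]), [2, 3]) -> 2,
#          _search_subsequence(torch.tensor([7, 1, 2, 3, 2, 3]), [2, 3], start=3) -> 4.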
def _mask_system_tokens(
input_ids: torch.Tensor,
labels: torch.Tensor,
tokenizer
):
"""
Identifies every occurrence of "System:" in `input_ids` (tokenized form),
then masks (sets to IGNORE_INDEX) from the first token of "System:" up to
the next "<end_of_utterance>" marker or the end of the entire sequence.
Args:
input_ids (torch.Tensor): The token IDs for the conversation.
labels (torch.Tensor): A copy of `input_ids` that we modify in-place
to set certain spans to IGNORE_INDEX.
tokenizer: The tokenizer.
"""
system_str = "System:"
end_str = "<end_of_utterance>"
system_ids = tokenizer.encode(system_str, add_special_tokens=False)
end_ids = tokenizer.encode(end_str, add_special_tokens=False)
start_pos = 0
while True:
# 1) find next "System:"
sys_start = _search_subsequence(input_ids, system_ids, start=start_pos)
if sys_start == -1:
break # no more occurrences
# 2) find next "<end_of_utterance>" after that
sys_end = _search_subsequence(input_ids, end_ids, start=sys_start + len(system_ids))
if sys_end == -1:
sys_end = len(input_ids) # if not found, go to end of sequence
# 3) Mask [sys_start .. sys_end) in 'labels'
labels[sys_start:sys_end] = IGNORE_INDEX
# 4) Move forward
start_pos = sys_end + len(end_ids)
def _mask_user_tokens(
input_ids: torch.Tensor,
labels: torch.Tensor,
tokenizer
):
"""
Identifies every occurrence of "User:" in `input_ids`,
then masks (sets to IGNORE_INDEX) from that token to the next "<end_of_utterance>"
or the end of the sequence. This removes the user's text from the training labels,
so the model won't try to predict user text.
Args:
input_ids (torch.Tensor): The token IDs for the conversation.
labels (torch.Tensor): A copy of `input_ids` that we modify in-place
to set certain spans to IGNORE_INDEX.
tokenizer: The tokenizer.
"""
user_str = "User:"
end_str = "<end_of_utterance>"
user_ids = tokenizer.encode(user_str, add_special_tokens=False)
end_ids = tokenizer.encode(end_str, add_special_tokens=False)
start_pos = 0
while True:
# 1) find next "User:"
usr_start = _search_subsequence(input_ids, user_ids, start=start_pos)
if usr_start == -1:
break # no more occurrences
# 2) find next "<end_of_utterance>" after that
usr_end = _search_subsequence(input_ids, end_ids, start=usr_start + len(user_ids))
if usr_end == -1:
usr_end = len(input_ids)
# 3) Mask [usr_start .. usr_end) in 'labels'
labels[usr_start:usr_end] = IGNORE_INDEX
# 4) Move forward
start_pos = usr_end + len(end_ids)
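# Illustrative effect: for a sequence like
#   "User: what is in the image?<end_of_utterance>Assistant: a cat<end_of_utterance>"
# _mask_user_tokens sets labels to IGNORE_INDEX from the "User:" tokens up to (but not
# including) the first "<end_of_utterance>", so only the assistant reply is supervised;
# _mask_system_tokens does the same for spans that start with "System:".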
##############################################################################
# Dataset
##############################################################################
class SupervisedDataset(Dataset):
def __init__(
self,
dataset_args: Dict[str, Any],
processor: transformers.ProcessorMixin,
data_args: DataArguments,
training_args: TrainingArguments,
model_args: ModelArguments,
):
"""
A dataset class that loads text/images/multi-image/videos,
tokenizes them via `processor`, and optionally masks user/system text.
Args:
dataset_args (Dict[str, Any]): Info specifying the dataset path,
sampling_strategy, possibly "source_fps", etc.
processor (ProcessorMixin): Usually a multi-modal HF processor
that has a tokenizer + image_processor for vision.
data_args (DataArguments): Contains config like `mask_user_tokens`,
`mask_system_tokens`, `fps`, etc.
            training_args (TrainingArguments): Possibly used for sampling or logging.
            model_args (ModelArguments): Provides `fps` and `frames_per_clip` used for video sampling.
        """
super().__init__()
self.mask_user_tokens = getattr(data_args, "mask_user_tokens", False)
self.mask_system_tokens = getattr(data_args, "mask_system_tokens", True)
self.add_media_intro_outro = getattr(data_args, "add_media_intro_outro", False)
self.processor = processor
self.tokenizer = processor.tokenizer
self.data_args = data_args
self.training_args = training_args
        # TODO: verify that these arguments are actually propagated here
        self.target_fps = getattr(model_args, "fps", 1.0)  # clip sampling rate (frames per second)
        self.frames_per_clip = getattr(model_args, "frames_per_clip", 1.0)  # number of frames per clip (to be averaged later)
self.max_frames = getattr(data_args, "max_frames", 25)
self.video_target_size = getattr(data_args, "video_target_size", 384)
self.image_target_size = getattr(data_args, "image_target_size", 1536)
self.data_folder = getattr(data_args, "data_folder", "")
subdir = dataset_args.get("path", "")
self.mm_path = os.path.join(self.data_folder, subdir)
self.name = dataset_args.get("name", "unnamed_dataset")
self.modality = dataset_args.get("modality", "unknown")
self.source_fps = dataset_args.get("source_fps", 1)
data_path = dataset_args["json_path"]
self.list_data_dict = self._load_data(data_path)
sampling_strategy = dataset_args.get("sampling_strategy", "all")
self._apply_sampling_strategy(sampling_strategy)
logger.info(
f"[SupervisedDataset: {self.name}] - Label Masking Logic. "
f"\nmask_user_tokens: {self.mask_user_tokens}, mask_system_tokens: {self.mask_system_tokens}\n"
)
logger.info(
f"[SupervisedDataset: {self.name}] Final dataset size: {len(self.list_data_dict)}\n"
f"Dataset Arguments - FPS: {self.target_fps}, "
f"Max Frames: {self.max_frames}, "
f"Video Target Size: {self.video_target_size}, "
f"Image Target Size: {self.image_target_size}"
)
def _load_data(self, json_path: str) -> List[Dict[str, Any]]:
if not os.path.isfile(json_path):
raise FileNotFoundError(f"File not found: {json_path}")
if json_path.endswith(".json"):
with open(json_path, "r") as f:
data = json.load(f)
elif json_path.endswith(".jsonl"):
data = []
with open(json_path, "r") as f:
for line in f:
data.append(json.loads(line.strip()))
else:
raise ValueError(f"Unsupported file format: {json_path}")
logger.info(f"[{self.name}] Loaded {len(data)} items from {json_path}")
return data
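    # Illustrative record layout (keys inferred from _get_item; path and text are hypothetical):
    #   {"type": "video", "video": "clips/0001.mp4",
    #    "conversations": [{"from": "human", "value": "<video>\nWhat happens here?"},
    #                      {"from": "gpt", "value": "A person opens a door."}]}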
def _apply_sampling_strategy(self, strategy: str):
if strategy == "all":
return
if ":" not in strategy:
return
kind, amount_str = strategy.split(":")
total = len(self.list_data_dict)
if amount_str.endswith("%"):
pct = float(amount_str.strip("%"))
sampling_number = max(1, math.ceil(total * pct / 100.0))
else:
sampling_number = int(amount_str)
if kind == "first":
self.list_data_dict = self.list_data_dict[:sampling_number]
elif kind == "end":
self.list_data_dict = self.list_data_dict[-sampling_number:]
elif kind == "random":
random.seed(42)
random.shuffle(self.list_data_dict)
self.list_data_dict = self.list_data_dict[:sampling_number]
logger.info(f"[{self.name}] after subsampling '{strategy}': {len(self.list_data_dict)} remain.")
def __len__(self) -> int:
return len(self.list_data_dict)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
        # TODO: make the number of retries configurable (e.g. via DataArguments)
        num_base_retries = 3
        # First, retry the requested sample a few times (e.g. transient cloud-disk issues).
        for attempt_idx in range(num_base_retries):
            try:
                return self._get_item(i)
            except Exception as e:
                logger.warning(f"[Try #{attempt_idx}] Failed to fetch sample {i}. Exception: {e}")
                time.sleep(1)
        # Next, try other samples in case the original file is corrupted.
        # NOTE: the fixed seed makes the fallback index deterministic for a given dataset size.
        for attempt_idx in range(num_base_retries):
            next_index = None
            try:
                random.seed(42)
                next_index = random.choice(range(len(self.list_data_dict)))
                return self._get_item(next_index)
            except Exception as e:
                logger.warning(f"[Try other #{attempt_idx}] Failed to fetch sample {next_index}. Exception: {e}")
        # Final attempt on the original index; let the exception propagate if it still fails.
        return self._get_item(i)
def _get_item(self, idx: int) -> Dict[str, torch.Tensor]:
sources = self.list_data_dict[idx]
if isinstance(idx, int):
sources = [sources]
content_type = sources[0].get("type", self.modality).lower()
frames: List[Image.Image] = []
timestamps: List[str] = []
duration_seconds = None
if content_type == "video":
## load videos
#self.processor.image_processor.size = (self.video_target_size, self.video_target_size)
self.processor.image_processor.size = {"longest_edge": self.video_target_size}
self.processor.image_processor.do_resize = True
self.processor.image_processor.do_image_splitting = False
media = sources[0].get("video") or sources[0].get("image")
if media:
path = os.path.join(self.mm_path, media)
if os.path.isdir(path):
## TODO: can we simplify this logic??
frames, timestamps, duration_seconds = load_image_directory_as_frames(
folder_path=path,
source_fps=self.source_fps,
target_fps=self.target_fps,
max_frames=self.max_frames
)
else:
                    # skip_secs controls how many seconds to skip at the start/end of the video
                    # before sampling frames; those frames are often noisy, so it is better to skip them.
                    # TODO: expose this as a data argument.
frames, timestamps, duration_seconds = load_video(
path,
max_frames=self.max_frames,
target_fps=self.target_fps,
skip_secs=1.0 # or data_args.skip_secs if you want
)
elif content_type == "image" or content_type == "multiimage":
## load images and multi-image
self.processor.image_processor.size = {"longest_edge": self.image_target_size}
self.processor.image_processor.do_resize = True
self.processor.image_processor.do_image_splitting = True
media = sources[0].get("image", False)
if media:
if isinstance(media, str):
media = [media]
paths = [os.path.join(self.mm_path, m) for m in media]
frames = [load_single_image(path) for path in paths]
else:
                raise ValueError("No image found for sample")
else:
frames = None
conversations = copy.deepcopy([e["conversations"] for e in sources])
## get system message
system_message = DEFAULT_SYSTEM_MESSAGE
for k, v in sources[0].items():
if isinstance(k, str) and "system" in k.lower() and "message" in k.lower() and isinstance(v, str):
system_message = v
break
# Ensure each conversation has a system turn at index 0
for conv in conversations:
system_idx = next((i for i, t in enumerate(conv) if t.get("from", "").lower() == "system"), None)
if system_idx is not None:
# Move existing system turn to index 0
conv.insert(0, conv.pop(system_idx))
else:
# If no system turn, add one
conv.insert(0, {"from": "system", "value": system_message})
conversations = [[self._convert_llava_to_openai_format(turn) for turn in conversation] for conversation in conversations]
conversations = [self._replace_multimodal_tokens(conversation, content_type, frames, timestamps) for conversation in conversations]
if self.add_media_intro_outro:
for conv in conversations:
if content_type == "text":
continue
elif content_type == "image" or content_type == "multiimage":
if conv[1]['content'][0]['type'] == "image":
conv[1]['content'].insert(0, {'type': 'text', 'text': DEFAULT_IMAGE_INTRO})
elif content_type == "video":
                    first_block = conv[1]['content'][0]
                    if first_block['type'] == "image" or (first_block['type'] == "text" and FRAME_TIMESTAMP_MESSAGE in first_block['text']):
                        intro = DEFAULT_VIDEO_INTRO.format(frame_count=num2words(len(frames)), video_duration=str(datetime.timedelta(seconds=duration_seconds)))
                        conv[1]['content'].insert(0, {'type': 'text', 'text': intro})
target_message_index = -1
last_image_index = -1
for i, message in enumerate(conv):
if 'content' in message:
for j, content in enumerate(message['content']):
if content.get('type') == 'image':
target_message_index = i
last_image_index = j
# If we found an image, insert the outro right after it in the content list
if target_message_index != -1 and last_image_index != -1:
conv[target_message_index]['content'].insert(last_image_index + 1,
{'type': 'text', 'text': DEFAULT_MEDIA_OUTTRO})
text_input = self.processor.apply_chat_template(conversations[0], add_generation_prompt=False)
encoded = self.processor(
text=text_input,
images=frames,
return_tensors="pt",
padding=False,
)
if encoded["input_ids"][0].size(0) > self.processor.tokenizer.model_max_length:
raise ValueError(f"Sequence length {encoded['input_ids'][0].size(0)} exceeds maximum {self.processor.tokenizer.model_max_length}")
# Each item is shape [1, seq_len]
input_ids = encoded["input_ids"][0]
attention_mask = encoded["attention_mask"][0]
# Start all labels as input_ids
labels = input_ids.clone()
self._mask_special_tokens(input_ids, labels)
if self.mask_system_tokens:
_mask_system_tokens(input_ids, labels, self.tokenizer)
if self.mask_user_tokens:
_mask_user_tokens(input_ids, labels, self.tokenizer)
out = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels,
}
if "pixel_values" in encoded:
out["pixel_values"] = encoded["pixel_values"][0]
return out
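    # A returned sample is a dict of tensors, roughly (shapes illustrative):
    #   {"input_ids": LongTensor[seq_len], "attention_mask": LongTensor[seq_len],
    #    "labels": LongTensor[seq_len], "pixel_values": FloatTensor[num_images, C, H, W]}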
def _convert_llava_to_openai_format(self, llava_entry: Dict[str, str]) -> Dict[str, Any]:
role_map = {"human": "user", "gpt": "assistant", "assistant": "assistant", "system": "system"}
speaker = llava_entry.get("from", "human").lower()
role = role_map.get(speaker, "user")
#text_value = llava_entry["value"].replace(DATA_VIDEO_TOKEN, DATA_IMAGE_TOKEN)
text_value = llava_entry["value"]
# Only replace <video> with <image> in user messages
if role == "user":
text_value = text_value.replace(DATA_VIDEO_TOKEN, DATA_IMAGE_TOKEN)
else:
# For assistant messages, replace tags with descriptive text
text_value = text_value.replace(DATA_VIDEO_TOKEN, "video tag")
text_value = text_value.replace(DATA_IMAGE_TOKEN, "image tag")
# This regex splits on (<image>) but keeps it in the returned list
parts = re.split(fr'({DATA_IMAGE_TOKEN})', text_value)
content_list = []
for chunk in parts:
chunk = chunk.strip()
if not chunk:
continue
            if chunk == DATA_IMAGE_TOKEN:
content_list.append({"type": "image"})
else:
content_list.append({"type": "text", "text": chunk})
# Fallback if the text was empty or something went wrong
if not content_list:
content_list = [{"type": "text", "text": text_value}]
return {"role": role, "content": content_list}
def _replace_multimodal_tokens(
self,
conversation: List[Dict[str, Any]],
content_type: str,
frames: List[Image.Image],
timestamps: List[Optional[str]],
) -> List[Dict[str, Any]]:
"""
Post-processes a conversation to handle missing or expanded "image"/"video" tokens
based on the loaded frames. If there's no explicit placeholder but frames exist,
it injects them at the start of the user's first message. If there's exactly one
placeholder token, it replicates it for every frame. This ensures the user's text
references the correct # of frames.
#TODO: add logger.warning if no <image> placeholder!
Args:
conversation (List[Dict[str, Any]]): The conversation with "role" ("user"/"assistant")
and "content" (list of dict: {"type":..., "text":...}).
content_type (str): "image" or "video". If "video" and multiple frames exist,
we replicate placeholders for each frame.
frames (List[Image.Image]): The frames we loaded from the local path.
Could be 0 or more.
timestamps (List[Optional[str]]): Timestamps (like "00:05") for each frame.
If some are None, we fallback to `i / self.target_fps`.
Returns:
conversation: The same conversation structure, but updated content blocks
so that each frame is referenced if needed.
"""
frames_inserted = False
first_user_msg_processed = False
for msg in conversation:
# We only modify the user's messages
if msg["role"] != "user":
continue
# Check if there's an explicit image/video token in the user's content
has_image_token = any(
block["type"] in ("image", "video")
for block in msg["content"]
)
# If we haven't processed the user's first message yet and no placeholders exist,
# but we DO have frames, let's insert them at the beginning.
if not first_user_msg_processed and not has_image_token and frames:
if content_type == "image":
# For a single-image scenario, just insert 1 placeholder
msg["content"].insert(0, {"type": "image"})
frames_inserted = True # We won't re-insert later
elif content_type == "video":
# Possibly multiple frames to insert
new_blocks = []
for i, frame in enumerate(frames):
# Use the provided timestamps if available
if timestamps and i < len(timestamps) and timestamps[i] is not None:
ts_str = timestamps[i]
else:
# Fallback: approximate from i / fps
sec = i / self.target_fps
mins = int(sec // 60)
secs = int(sec % 60)
ts_str = f"{mins:02d}:{secs:02d}"
new_blocks.append({"type": "text", "text": f"{FRAME_TIMESTAMP_MESSAGE} {ts_str}:"})
new_blocks.append({"type": "image"})
# Prepend to the user content
msg["content"] = new_blocks + msg["content"]
frames_inserted = True
# Now we check placeholders inside the existing content to see if we want to expand them.
updated_content = []
for block in msg["content"]:
if content_type == "video" and (block["type"] in ("image", "video")):
# If there's a single "image"/"video" token, we can replicate it for each frame
# *if we haven't done so already*.
if not frames_inserted and frames:
for i, frame in enumerate(frames):
# Use an existing or fallback timestamp
if timestamps and i < len(timestamps) and timestamps[i] is not None:
ts_str = timestamps[i]
else:
sec = i / self.target_fps
mins = int(sec // 60)
secs = int(sec % 60)
ts_str = f"{mins:02d}:{secs:02d}"
updated_content.append({"type": "text", "text": f"Frame from {ts_str}:"})
updated_content.append({"type": "image"})
frames_inserted = True
# If we've already inserted frames, we can skip adding another placeholder
elif content_type == "image" and block["type"] == "image":
# For images, if we want exactly one placeholder, we keep it.
                    # NOTE: for multi-image datasets we assume the sample already contains the
                    # correct number of image tokens; otherwise there is no way to know where
                    # the extra images should go.
                    # TODO: warn if a multi-image sample has fewer image tokens than images.
updated_content.append({"type": "image"})
else:
# All text blocks or anything else are kept
updated_content.append(block)
msg["content"] = updated_content
first_user_msg_processed = True
return conversation
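    # Illustrative expansion for a two-frame video (timestamps "00:00" and "00:05"):
    #   user content [{"type": "image"}, {"type": "text", "text": "What happens?"}]
    #   -> [{"type": "text", "text": "Frame from 00:00:"}, {"type": "image"},
    #       {"type": "text", "text": "Frame from 00:05:"}, {"type": "image"},
    #       {"type": "text", "text": "What happens?"}]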
def _mask_special_tokens(self, input_ids: torch.Tensor, labels: torch.Tensor):
labels[input_ids == self.tokenizer.pad_token_id] = IGNORE_INDEX
if DEFAULT_IMAGE_TOKEN in self.tokenizer.additional_special_tokens:
image_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_IMAGE_TOKEN)
labels[input_ids == image_id] = IGNORE_INDEX
if DEFAULT_VIDEO_TOKEN in self.tokenizer.additional_special_tokens:
            video_id = self.tokenizer.convert_tokens_to_ids(DEFAULT_VIDEO_TOKEN)
            labels[input_ids == video_id] = IGNORE_INDEX
if '<global-img>' in self.tokenizer.get_vocab():
global_img_id = self.tokenizer.convert_tokens_to_ids('<global-img>')
labels[input_ids == global_img_id] = IGNORE_INDEX
image_patches = re.compile(r'<row_\d+_col_\d+>')
patch_tokens = [token for token in self.tokenizer.get_vocab() if image_patches.fullmatch(token)]
if len(patch_tokens) > 0:
row_token_ids = self.tokenizer.convert_tokens_to_ids(patch_tokens)
for token_id in row_token_ids:
labels[input_ids == token_id] = IGNORE_INDEX
# Possibly also ignore custom placeholders
tokens = self.tokenizer.convert_ids_to_tokens(input_ids)
positions_to_mask = find_global_img_patterns(tokens) + find_row_col_patterns(tokens)
if len(positions_to_mask) > 0:
logger.warn(f"found {len} global image + row col tokens not tokenized correctly!")
for pos in positions_to_mask:
labels[pos] = IGNORE_INDEX
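# Minimal construction sketch (argument values are illustrative, not defaults; the processor
# id and dataset_args keys below are assumptions based on how this class reads them):
#   from transformers import AutoProcessor
#   processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
#   dataset_args = {"name": "my_videos", "modality": "video", "path": "videos/",
#                   "json_path": "annotations/train.jsonl", "sampling_strategy": "first:1000"}
#   dataset = SupervisedDataset(dataset_args, processor, data_args, training_args, model_args)
#   sample = dataset[0]  # dict with input_ids / attention_mask / labels / (pixel_values)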