# vision/smolvlm2/smolvlm/datasets/builder.py
import os
import math
import random
import logging
import yaml
from typing import Dict, Any, List, Optional
from collections import defaultdict
import torch
from torch.utils.data import Dataset, ConcatDataset
from smolvlm.datasets.dataset import SupervisedDataset
from smolvlm.train.args import DataArguments, TrainingArguments, ModelArguments
from smolvlm.constants import IGNORE_INDEX
from smolvlm.utils import mprint # or your custom printing utility
from tabulate import tabulate
from transformers import ProcessorMixin
from torch.nn.utils.rnn import pad_sequence
logger = logging.getLogger(__name__)
##############################################################################
# Multi-subsequence "varlen" packing with integer-coded `subseq_ids`
##############################################################################
def len2weight(num_effective_tokens: int, mode: str) -> float:
"""
Returns the sub-sample weight given the sub-sequence length.
"""
if num_effective_tokens == 0:
return 0.0 # or skip
if mode == "token":
return 1.0 # no length-based weighting
elif mode == "sample":
# each sub-sample counts equally, so 1 / length
return 1.0 / num_effective_tokens
elif mode == "square":
# default in InternVL
return 1.0 / (num_effective_tokens**0.5)
else:
# 'none' or fallback
return 1.0
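# Example (illustrative): for a sub-sample with 16 effective tokens,
#   len2weight(16, "token")  -> 1.0
#   len2weight(16, "sample") -> 0.0625   (1 / 16)
#   len2weight(16, "square") -> 0.25     (1 / sqrt(16))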
class PackedConcatDataset(ConcatDataset):
"""
Merges multiple short sub-samples from an underlying ConcatDataset into a
single “packed” sample, up to a `cutoff_len` tokens. Assigns integer-coded
sub-sequence IDs in `subseq_ids`:
1 => sub-sample #1
2 => sub-sample #2
...
so your collator can turn them into block diagonal (varlen) attention masks.
Each returned item from __getitem__ is:
{
"input_ids": (sum_of_sub_len,) int,
"labels": (sum_of_sub_len,) int,
"subseq_ids": (sum_of_sub_len,) int in [1..N],
"pixel_values": (sum_of_frames, 3, H, W) if images exist
}
We do NOT do final token-level padding to cutoff_len; we let the data collator
handle batch-level padding (less wasteful).
Attributes:
datasets: List of sub-datasets we are merging.
cutoff_len: Max tokens we want in a single “packed” sample.
pad_token_id: If needed for partial fix-ups.
packed_cursor: Tracks how many sub-samples we have consumed so far.
"""
def __init__(
self,
datasets: List,
data_args,
model_max_length: int = 2048
):
super().__init__(datasets)
self.data_args = data_args
self.cutoff_len = max(model_max_length, 1)
self.pad_token_id = getattr(data_args, "pad_token_id", 0)
self.packed_cursor = 0
logger.info(
f"[PackedConcatDataset] Using cutoff_len={self.cutoff_len}; "
f"we'll merge multiple sub-samples per item."
)
def __len__(self):
# Underlying total number of sub-samples
return super().__len__()
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""
We ignore 'idx' because each call to __getitem__ gets the “next” packed sample,
from self.packed_cursor onward.
"""
if self.packed_cursor >= len(self):
raise IndexError("No more sub-samples left to pack in PackedConcatDataset.")
# Accumulate sub-samples
chunk_input_ids = []
chunk_labels = []
chunk_subseq_ids = []
pixel_key = None
pixel_values_list = []
sub_seq_counter = 0
current_token_count = 0
while True:
if self.packed_cursor >= len(self):
break
sub_item = super().__getitem__(self.packed_cursor)
self.packed_cursor += 1
sub_len = sub_item["input_ids"].size(0)
if (current_token_count > 0) and (current_token_count + sub_len) > self.cutoff_len:
# Revert if we can't fit this sub-sample
self.packed_cursor -= 1
break
sub_seq_counter += 1
seq_id_tensor = torch.full(
(sub_len,),
fill_value=sub_seq_counter,
dtype=torch.long,
device=sub_item["input_ids"].device
)
chunk_input_ids.append(sub_item["input_ids"])
chunk_labels.append(sub_item["labels"])
chunk_subseq_ids.append(seq_id_tensor)
# If images are present
if "pixel_values" in sub_item:
pixel_key = "pixel_values"
pixel_values_list.append(sub_item["pixel_values"])
current_token_count += sub_len
print("[Sequence Packing] current num tokens:", current_token_count)
if current_token_count >= self.cutoff_len:
break
# Merge text
if len(chunk_input_ids) == 0:
return {
"input_ids": torch.tensor([], dtype=torch.long),
"labels": torch.tensor([], dtype=torch.long),
"attention_mask": torch.tensor([], dtype=torch.long),
}
merged_input_ids = torch.cat(chunk_input_ids, dim=0)
merged_labels = torch.cat(chunk_labels, dim=0)
merged_subseq_ids = torch.cat(chunk_subseq_ids, dim=0)
# Merge images along frame dimension if present
merged_pixel_values = None
if pixel_key and pixel_values_list:
merged_pixel_values = torch.cat(pixel_values_list, dim=0)
# shape => (f1+f2+..., 3, H, W)
loss_weight = torch.ones_like(merged_subseq_ids, dtype=torch.float32)
unique_ids = merged_subseq_ids.unique()
unique_ids = unique_ids[unique_ids > 0] # ignore pad=0
for sid in unique_ids.tolist():
mask = (merged_subseq_ids == sid)
num_eff = (merged_labels[mask] != IGNORE_INDEX).sum().item()
w = len2weight(num_eff, self.data_args.loss_reduction)
loss_weight[mask] = w
        # Build the packed sample; `attention_mask` carries the integer-coded
        # sub-sequence IDs (the varlen collator later pads them with 0)
out_dict = {
"input_ids": merged_input_ids,
"labels": merged_labels,
"attention_mask": merged_subseq_ids,
"loss_weight": loss_weight,
}
if merged_pixel_values is not None:
out_dict[pixel_key] = merged_pixel_values
return out_dict
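# Illustrative sketch (an assumption about downstream use, not code the pipeline
# calls): one way to expand the integer-coded sub-sequence IDs above into an
# explicit block-diagonal attention mask, for models that need a full (s, s) mask
# rather than the IDs themselves.
def subseq_ids_to_block_diag_mask(subseq_ids: torch.Tensor) -> torch.Tensor:
    """
    Converts a padded (batch, seq_len) tensor of integer-coded sub-sequence IDs
    (0 = padding) into a (batch, seq_len, seq_len) boolean mask in which a token
    may only attend to tokens sharing its (non-zero) sub-sequence ID.
    """
    same_id = subseq_ids.unsqueeze(1) == subseq_ids.unsqueeze(2)              # (b, s, s)
    not_pad = (subseq_ids > 0).unsqueeze(1) & (subseq_ids > 0).unsqueeze(2)   # (b, s, s)
    return same_id & not_pad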
##############################################################################
# Varlen Collator (subseq_ids => block diagonal)
##############################################################################
def pad_sequence_varlen(
sequences: List[torch.Tensor],
batch_first: bool = True,
padding_value: float = 0.0
) -> torch.Tensor:
"""
Similar signature to torch.nn.utils.rnn.pad_sequence, but treats each Tensor
as a 1D sequence of any integer-coded or float tokens, and uses integer-coded
varlen semantics if desired.
If batch_first=True, returns (batch_size, seq_len).
Otherwise returns (seq_len, batch_size).
"""
if len(sequences) == 0:
# Return an empty tensor if no sequences
return torch.tensor([], dtype=torch.long)
max_len = max(seq.size(0) for seq in sequences)
batch_size = len(sequences)
device = sequences[0].device
dtype = sequences[0].dtype
if batch_first:
# Shape => (batch_size, max_len)
out = torch.full((batch_size, max_len), padding_value, device=device, dtype=dtype)
for i, seq in enumerate(sequences):
length = seq.size(0)
out[i, :length] = seq
else:
# Shape => (max_len, batch_size)
out = torch.full((max_len, batch_size), padding_value, device=device, dtype=dtype)
for i, seq in enumerate(sequences):
length = seq.size(0)
out[:length, i] = seq
return out
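# Example (illustrative): padding two integer-coded `attention_mask` tensors;
# padding_value=0 marks positions that belong to no sub-sequence:
#   pad_sequence_varlen([torch.tensor([1, 1, 2, 2, 2]), torch.tensor([1, 1, 1])],
#                       batch_first=True, padding_value=0)
#   -> tensor([[1, 1, 2, 2, 2],
#              [1, 1, 1, 0, 0]])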
##############################################################################
# Standard Collator (0/1 attn mask)
##############################################################################
class DataCollatorForSupervisedDataset:
"""
Collates examples containing text-only or text+image/video data.
1) Text sequences (input_ids, attention_mask, labels) are padded to the maximum batch length.
- If model_max_length is set, we optionally truncate.
2) Pixel data (pixel_values, optional) is padded to (max_frames, 3, max_h, max_w).
3) Pixel-level mask (pixel_attention_mask):
- If provided in the example, we pad it accordingly.
- If not provided, we fill the valid image region with 1, the remainder with 0.
"""
def __init__(
self,
pad_token_id: int,
model_max_length: Optional[int] = None,
image_size: Optional[int] = None,
func_pad_sequence = pad_sequence,
):
self.pad_token_id = pad_token_id
self.ignore_index = IGNORE_INDEX
self.model_max_length = model_max_length
self.image_size = image_size
self.func_pad_sequence = func_pad_sequence
def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
################################################################
# PART A: Pad the text data (input_ids, attention_mask, labels)
################################################################
attention_masks_list = []
for ex in examples:
# If "attention_mask" is missing, we generate it on the fly
if "attention_mask" in ex:
attention_masks_list.append(ex["attention_mask"])
else:
am = (ex["input_ids"] != self.pad_token_id).long()
attention_masks_list.append(am)
input_ids = self.func_pad_sequence(
[ex["input_ids"] for ex in examples],
batch_first=True,
padding_value=self.pad_token_id
)
attention_mask = self.func_pad_sequence(
attention_masks_list,
batch_first=True,
padding_value=0
)
labels = self.func_pad_sequence(
[ex["labels"] for ex in examples],
batch_first=True,
padding_value=self.ignore_index
)
# Optional: truncate if model_max_length is specified
if self.model_max_length and input_ids.size(1) > self.model_max_length:
input_ids = input_ids[:, :self.model_max_length]
attention_mask = attention_mask[:, :self.model_max_length]
labels = labels[:, :self.model_max_length]
out = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"labels": labels
}
################################################################
# PART B: Handle pixel data (pixel_values) + pixel_attention_mask
################################################################
# Step 1: figure out maximum frames, height, width across the batch
pvs = [ex["pixel_values"] for ex in examples if "pixel_values" in ex]
if pvs: # there is at least one non-None pixel_values
max_frames = max(pv.shape[0] for pv in pvs)
max_h = max(pv.shape[-2] for pv in pvs)
max_w = max(pv.shape[-1] for pv in pvs)
else:
max_h = max_w = self.image_size
            max_frames = 1  # TODO: verify this is a good default; image_size must be set for text-only batches
# Step 2: create padded pixel_values and pixel_attention_mask for each example
padded_pixel_values_list = []
padded_pixel_mask_list = []
for ex in examples:
pv = ex.get("pixel_values", None)
pm = ex.get("pixel_attention_mask", None) # shape (f, h, w) if provided
if pv is None:
# text-only => fill pixel data + mask with zeros
shape_pv = (max_frames, 3, max_h, max_w)
shape_pm = (max_frames, max_h, max_w)
padded_pv = torch.zeros(shape_pv, dtype=torch.float32)
padded_pm = torch.zeros(shape_pm, dtype=torch.long)
else:
f, c, h, w = pv.shape
# Prepare final storage
padded_pv = torch.zeros(
(max_frames, c, max_h, max_w),
dtype=pv.dtype,
device=pv.device
)
padded_pm = torch.zeros(
(max_frames, max_h, max_w),
dtype=torch.long,
device=pv.device
)
padded_pv[:f, :, :h, :w] = pv
# Copy or fill the pixel attention mask
if pm is not None:
padded_pm[:f, :h, :w] = pm
else:
# Mark valid region as 1
padded_pm[:f, :h, :w] = 1
padded_pixel_values_list.append(padded_pv)
padded_pixel_mask_list.append(padded_pm)
        # Finally, stack along the batch dimension (pixel_attention_mask marks the
        # valid image region, as described in the class docstring).
        ## try not outputting pixel_values in text-only sample
        #if any("pixel_values" in ex for ex in examples):
        out["pixel_values"] = torch.stack(padded_pixel_values_list, dim=0)
        out["pixel_attention_mask"] = torch.stack(padded_pixel_mask_list, dim=0)
return out
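# Illustrative usage (text-only batch; pad_token_id=0 and image_size=384 are
# arbitrary example values, not pipeline defaults):
#   collator = DataCollatorForSupervisedDataset(pad_token_id=0, image_size=384)
#   batch = collator([
#       {"input_ids": torch.tensor([5, 6, 7]), "labels": torch.tensor([5, 6, 7])},
#       {"input_ids": torch.tensor([8, 9]),    "labels": torch.tensor([8, 9])},
#   ])
#   batch["input_ids"].shape            -> (2, 3)
#   batch["attention_mask"]             -> [[1, 1, 1], [1, 1, 0]]
#   batch["labels"][1, 2]               -> IGNORE_INDEX
#   batch["pixel_values"].shape         -> (2, 1, 3, 384, 384)  (zeros for text-only)
#   batch["pixel_attention_mask"].shape -> (2, 1, 384, 384)     (zeros for text-only)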
##############################################################################
# Summaries
##############################################################################
def display_overview(summary_data: Dict[str, Any], total_count: int) -> None:
print("=== Overview ===")
print(f"Aggregate Sample Count: {total_count}\n")
for category, info in summary_data.items():
ctotal = info["total_samples"]
cpct = (ctotal / total_count * 100) if total_count > 0 else 0
print(f"{category.title()} Overview")
print(f"Number of Samples: {ctotal} ({cpct:.2f}%)")
print("-" * 50)
table_data = []
headers = ["Dataset", "Count", "Percentage"]
for entry in info["datasets"]:
esamples = entry["samples"]
epct = (esamples / total_count * 100) if total_count > 0 else 0
table_data.append([entry["dataset_name"], esamples, f"{epct:.2f}%"])
print(tabulate(table_data, headers=headers, tablefmt="fancy_grid"))
print()
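# Illustrative call (toy numbers; "demo_clips" is a hypothetical dataset name):
#   display_overview(
#       {"video": {"total_samples": 80,
#                  "datasets": [{"dataset_name": "demo_clips", "samples": 80}]}},
#       total_count=100,
#   )
# prints the overview header plus a tabulate table with Dataset / Count / Percentage columns.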
##############################################################################
# Main builder logic
##############################################################################
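# Expected shape of the data-mixture YAML: a mapping from dataset group to a list
# of per-dataset entries. The keys inside each entry are whatever SupervisedDataset
# consumes; the ones shown below are purely hypothetical placeholders.
#
#   image:
#     - name: some_image_dataset
#       data_path: /path/to/annotations.json
#   video:
#     - name: some_video_dataset
#       data_path: /path/to/annotations.json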
def build_datasets(
data_args: DataArguments,
training_args: TrainingArguments,
model_args: ModelArguments,
processor: ProcessorMixin,
split: str = "train",
):
"""
1) Load a YAML describing multiple sub-datasets.
2) Create a list of SupervisedDataset objects.
3) If data_args.packed => use PackedConcatDataset (with subseq_ids),
else => normal ConcatDataset.
"""
if getattr(model_args, "frames_per_clip", 1) > 1:
from smolvlm.datasets.dataset_clip_sampling import SupervisedDataset
else:
from smolvlm.datasets.dataset import SupervisedDataset
mprint(f"[Dataset-INFO]: Loading from {data_args.data_mixture}")
with open(data_args.data_mixture, "r") as yf:
meta_datasets = yaml.safe_load(yf)
all_datasets = []
extra_info = []
for dataset_type, dataset_list in meta_datasets.items():
for ds_args in dataset_list:
ds = SupervisedDataset(
dataset_args=ds_args,
processor=processor,
data_args=data_args,
training_args=training_args,
model_args=model_args,
)
all_datasets.append(ds)
extra_info.append({
"dataset_name": ds.name,
"modality": ds.modality,
"samples": len(ds),
})
    # Summaries (defaultdict is already imported at module level)
    modality_summary = defaultdict(lambda: {"total_samples": 0, "datasets": []})
total_samples = 0
for entry in extra_info:
mod = entry["modality"]
dsname = entry["dataset_name"]
samples = entry["samples"]
modality_summary[mod]["total_samples"] += samples
modality_summary[mod]["datasets"].append(entry)
total_samples += samples
display_overview(modality_summary, total_samples)
# Build final dataset
if data_args.packed:
mprint("[build_datasets] Using PackedConcatDataset for multi-sample packing with subseq_ids.")
dataset = PackedConcatDataset(
all_datasets,
data_args=data_args,
model_max_length=training_args.model_max_length
)
else:
mprint("[build_datasets] Using standard ConcatDataset (no packing).")
dataset = ConcatDataset(all_datasets)
# Save some info in training_args
training_args.sample_lens = [e["samples"] for e in extra_info]
training_args.data_info = extra_info
return dataset
def make_supervised_data_module(
processor: ProcessorMixin,
data_args: DataArguments,
training_args: TrainingArguments,
model_args: ModelArguments,
):
"""
Creates train_dataset, eval_dataset, and data_collator.
If data_args.packed => we do integer-coded subseq_ids approach,
else => normal approach with 0/1 attention_mask.
"""
train_dataset = build_datasets(data_args, training_args, model_args, processor, split="train")
eval_dataset = None
    # Same collator class either way; only the padding function differs:
    # pad_sequence_varlen pads the integer-coded subseq-ID attention_mask with 0,
    # while torch's pad_sequence serves the standard 0/1 attention_mask case.
    func_pad = pad_sequence_varlen if data_args.packed else pad_sequence
    data_collator = DataCollatorForSupervisedDataset(
        pad_token_id=processor.tokenizer.pad_token_id,
        model_max_length=processor.tokenizer.model_max_length,
        image_size=data_args.video_target_size,
        func_pad_sequence=func_pad,
    )
return dict(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
data_collator=data_collator,
)
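# Typical wiring in a training entry point (sketch; `model` and the parsed argument
# dataclasses are assumed to come from the caller, e.g. via HfArgumentParser):
#   data_module = make_supervised_data_module(processor, data_args, training_args, model_args)
#   trainer = Trainer(model=model, args=training_args, **data_module)
#   trainer.train()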