data/advanced_datasets.py
import torch
from torch.utils.data import IterableDataset, get_worker_info
import threading
from queue import Queue
from typing import Iterator
import itertools
import random
random.seed(42) # Fix the seed (42, the meaning of life) so the knapsack shuffle is reproducible
class ConstantLengthDataset(IterableDataset):
def __init__(
self,
dataset,
infinite: bool = False,
max_sample_length: int = 1024,
seq_length: int = 1024,
num_of_sequences: int = 1024,
queue_size: int = 2048,
max_images_per_example: int = 4,
max_images_per_knapsack: int = 18,
):
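        """
        Wrap a map-style multimodal dataset and stream fixed-length packed sequences.

        Args:
            dataset: Underlying dataset; must expose `mp_image_token_length`, a
                `tokenizer` with `pad_token_id`, and yield dicts containing
                `input_ids`, `labels`, `attention_mask`, and `images`.
            infinite: If True, restart iteration (and increment `epoch`) once the
                underlying dataset is exhausted.
            max_sample_length: Samples with at least this many tokens are skipped.
            seq_length: Target length of each packed sequence.
            num_of_sequences: How many sequences' worth of tokens to buffer before
                each knapsack packing pass.
            queue_size: Capacity of the producer-consumer prefetch queue.
            max_images_per_example: Samples with more images than this are skipped.
            max_images_per_knapsack: Upper bound on images in one packed sequence.
        """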
self.dataset = dataset
self.max_sample_length = max_sample_length
self.seq_length = seq_length
self.max_length = seq_length * num_of_sequences
self.epoch = 0 # only advanced when infinite=True
self.infinite = infinite
self.queue_size = queue_size
self.max_images_per_example = max_images_per_example
self.max_images_per_knapsack = max_images_per_knapsack
self._sentinel = object()
self._average_length_per_sample = (
self.dataset.mp_image_token_length + 198
        ) # 198 is the average number of text tokens per sample in the cauldron dataset
def __len__(self):
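        # Rough estimate: expected total tokens in the dataset divided by seq_length.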
return int(
len(self.dataset) * self._average_length_per_sample / self.seq_length
)
def __iter__(self) -> Iterator[dict]:
"""
Returns an iterator over the dataset that yields fixed-length sequences for training.
The iterator uses a producer-consumer pattern with a background thread to efficiently
pre-fetch and buffer samples. The producer thread continuously reads from the base
dataset and fills a queue, while the main thread consumes from the queue.
The dataset is automatically sharded across workers when using num_workers > 1.
Returns:
Iterator[dict]: An iterator that yields training samples with the following structure:
- input_ids: Tensor of token ids of shape (seq_length,)
- labels: Tensor of labels of shape (seq_length,)
- attention_mask: Tensor of attention mask of shape (seq_length,)
- images: List of processed image tensors
"""
worker_info = get_worker_info()
worker_id = worker_info.id if worker_info else 0
num_workers = worker_info.num_workers if worker_info else 1
def make_base_iterator():
"""Return a (sharded) iterator over the underlying dataset."""
all_indices = range(len(self.dataset))
# Shard the *indices* first, before any data is fetched.
if num_workers > 1:
worker_indices = itertools.islice(
all_indices, worker_id, None, num_workers
)
else:
worker_indices = all_indices
# Create an iterator that only calls __getitem__ for the assigned indices.
def sharded_item_iterator():
for idx in worker_indices:
yield self.dataset[idx]
return sharded_item_iterator()
queue: Queue = Queue(maxsize=self.queue_size)
producer = threading.Thread(
target=self._producer, args=(make_base_iterator, queue), daemon=True
)
producer.start()
while True:
sample = queue.get()
if sample is self._sentinel:
break
yield sample
def _producer(
self,
        make_iterator, # a zero-arg callable that returns a fresh (possibly sharded) iterator
queue: Queue,
):
"""Runs in a separate daemon thread and keeps `queue` full."""
iterator = make_iterator()
more_examples = True
while more_examples:
# ------------- 1) pull raw samples until we have enough -------- #
buffer, buffer_len = [], 0
while buffer_len < self.max_length:
try:
sample = next(iterator)
except StopIteration:
if self.infinite:
iterator = make_iterator()
self.epoch += 1
print(f"Epoch {self.epoch} finished, restarting iterator")
continue
else:
more_examples = False
break
if len(sample["input_ids"]) >= self.max_sample_length:
continue # skip overly long samples
if len(sample["images"]) > self.max_images_per_example:
continue # skip samples that exceed the image constraint
sample["input_ids"] = torch.cat(
[
sample["input_ids"],
torch.tensor([self.dataset.tokenizer.pad_token_id]),
]
)
sample["attention_mask"] = torch.cat(
[sample["attention_mask"], torch.tensor([0])]
)
sample["labels"] = torch.cat([sample["labels"], torch.tensor([-100])])
buffer.append(sample)
buffer_len += len(sample["input_ids"])
if not buffer:
break # nothing left and not infinite
# ------------- 2) run greedy knapsack & pack groups ------------ #
groups = self._balanced_greedy_knapsack(
buffer,
self.seq_length,
delta=5,
max_images_per_knapsack=self.max_images_per_knapsack,
)
for g in groups:
packed = self._pack_one_group(g, buffer, self.seq_length)
                # put() blocks while the queue is full, throttling the producer.
queue.put(
{
"input_ids": packed[0],
"labels": packed[1],
"attention_mask": packed[2],
"images": packed[3],
}
)
# finished → unblock consumer
queue.put(self._sentinel)
def _balanced_greedy_knapsack(
self, buffer, L, delta=0, max_images_per_knapsack=None
):
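        """
        Greedily assign buffered samples to bins ("knapsacks") of capacity `L` tokens.

        Items are sorted by length (longest first) and each one is placed into the
        least-loaded bin that still satisfies both the token-length and image-count
        constraints; if no bin fits, a new one is opened. `delta` extra bins are
        pre-allocated as slack, and empty bins are dropped before returning a
        shuffled list of index groups into `buffer`.
        """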
# Extract lengths and image counts from buffer
lengths = [len(x["input_ids"]) for x in buffer]
image_counts = [len(x["images"]) for x in buffer]
# keep the position while sorting
items = sorted(
enumerate(zip(lengths, image_counts)), key=lambda x: x[1][0], reverse=True
)
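        # Lower bound on the bin count, ceil(total_tokens / L), plus `delta` spare
        # bins so the greedy pass has slack to balance loads.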
min_knapsacks = (sum(lengths) + L - 1) // L + delta
knapsack_load = [0] * min_knapsacks
knapsack_image_counts = [0] * min_knapsacks
knapsack_groups = [[] for _ in range(min_knapsacks)]
for idx, (item_len, item_image_count) in items:
# Find a suitable knapsack that satisfies both length and image count constraints
suitable_knapsack = None
# First try to find a knapsack that can fit both constraints
for ks_id in sorted(
range(len(knapsack_load)), key=knapsack_load.__getitem__
):
length_fits = knapsack_load[ks_id] + item_len <= L
image_fits = (
max_images_per_knapsack is None
or knapsack_image_counts[ks_id] + item_image_count
<= max_images_per_knapsack
)
if length_fits and image_fits:
suitable_knapsack = ks_id
break
# If no existing knapsack can fit, create a new one
if suitable_knapsack is None:
suitable_knapsack = len(knapsack_load)
knapsack_load.append(0)
knapsack_image_counts.append(0)
knapsack_groups.append([])
knapsack_groups[suitable_knapsack].append(idx)
knapsack_load[suitable_knapsack] += item_len
knapsack_image_counts[suitable_knapsack] += item_image_count
        # Knapsacks come out semi-ordered after greedy packing, so shuffle them
        # (thanks Luis for noticing!).
        random.shuffle(knapsack_groups)
        # Drop the completely empty bins that the +delta heuristic created.
        return [g for g in knapsack_groups if g]
def _pack_one_group(self, group_indices, batch, max_len):
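        """Concatenate the samples at `group_indices` in `batch` into one packed example."""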
ids, lbl, am, ims = [], [], [], []
for i in group_indices:
ids.extend(batch[i]["input_ids"])
lbl.extend(batch[i]["labels"])
am.extend(batch[i]["attention_mask"])
ims.extend(batch[i]["images"])
        # Safety check: a packed group must never exceed the target length.
        if len(ids) > max_len:
            raise ValueError(f"Packed length {len(ids)} > max_len {max_len}")
return torch.stack(ids), torch.stack(lbl), torch.stack(am), ims
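

if __name__ == "__main__":
    # Minimal smoke-test sketch: exercises the packing pipeline end to end with a
    # synthetic dataset. The toy attributes below (mp_image_token_length,
    # tokenizer.pad_token_id) and the sample dict layout are stand-ins that mirror
    # the interface ConstantLengthDataset reads above, not the real training data.
    from types import SimpleNamespace

    class _ToyVLMDataset:
        mp_image_token_length = 64
        tokenizer = SimpleNamespace(pad_token_id=0)

        def __len__(self):
            return 256

        def __getitem__(self, idx):
            n = 16 + (idx % 32)  # variable-length samples
            return {
                "input_ids": torch.randint(1, 100, (n,)),
                "labels": torch.randint(1, 100, (n,)),
                "attention_mask": torch.ones(n, dtype=torch.long),
                "images": [torch.zeros(3, 8, 8)],
            }

    packed = ConstantLengthDataset(
        _ToyVLMDataset(), seq_length=128, num_of_sequences=4
    )
    for step, batch in enumerate(packed):
        print(step, batch["input_ids"].shape, len(batch["images"]))
        if step >= 3:
            break
    # In training this dataset would typically be wrapped in a
    # torch.utils.data.DataLoader with num_workers > 1, which triggers the
    # per-worker index sharding implemented in __iter__.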