in data/advanced_datasets.py [0:0]
# Module-level imports this method relies on (at the top of data/advanced_datasets.py):
import itertools
import threading
from queue import Queue
from typing import Iterator

from torch.utils.data import get_worker_info

def __iter__(self) -> Iterator[dict]:
    """
    Return an iterator over the dataset that yields fixed-length sequences for training.

    The iterator uses a producer-consumer pattern with a background thread to
    efficiently pre-fetch and buffer samples: the producer thread continuously
    reads from the base dataset and fills a bounded queue, while the main
    thread consumes from the queue. The dataset is automatically sharded
    across workers when using num_workers > 1.

    Returns:
        Iterator[dict]: An iterator that yields training samples with the
        following structure:
            - input_ids: Tensor of token ids of shape (seq_length,)
            - labels: Tensor of labels of shape (seq_length,)
            - attention_mask: Tensor of attention mask of shape (seq_length,)
            - images: List of processed image tensors
    """
    worker_info = get_worker_info()
    worker_id = worker_info.id if worker_info else 0
    num_workers = worker_info.num_workers if worker_info else 1

    def make_base_iterator():
        """Return a (sharded) iterator over the underlying dataset."""
        all_indices = range(len(self.dataset))
        # Shard the *indices* first, before any data is fetched.
        if num_workers > 1:
            worker_indices = itertools.islice(
                all_indices, worker_id, None, num_workers
            )
        else:
            worker_indices = all_indices

        # Create an iterator that only calls __getitem__ for the assigned indices.
        def sharded_item_iterator():
            for idx in worker_indices:
                yield self.dataset[idx]

        return sharded_item_iterator()

    queue: Queue = Queue(maxsize=self.queue_size)
    producer = threading.Thread(
        target=self._producer, args=(make_base_iterator, queue), daemon=True
    )
    producer.start()

    while True:
        sample = queue.get()
        if sample is self._sentinel:
            break
        yield sample
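
# The `_producer` helper and `_sentinel` marker used above are not shown in
# this excerpt. What follows is a minimal sketch of what they presumably look
# like, inferred only from how they are called here; the actual definitions in
# data/advanced_datasets.py may differ. A common pattern for the sentinel is a
# unique class-level marker object (assumption):
#
#     _sentinel: object = object()

def _producer(self, make_base_iterator, queue: Queue) -> None:
    """Fill `queue` from the base iterator, then signal completion (sketch)."""
    try:
        for sample in make_base_iterator():
            # put() blocks while the queue is full, so pre-fetching is bounded
            # by `queue_size` and memory use stays capped.
            queue.put(sample)
    finally:
        # Always enqueue the sentinel so the consumer loop in __iter__ exits,
        # even if the base iterator raises.
        queue.put(self._sentinel)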
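
# Usage sketch: because __iter__ shards indices by worker id via
# get_worker_info(), the dataset can be wrapped in a DataLoader with multiple
# workers and each worker yields a disjoint slice of the samples. The class
# name `AdvancedDataset` and its constructor arguments are hypothetical:
#
#     from torch.utils.data import DataLoader
#
#     loader = DataLoader(AdvancedDataset(...), batch_size=8, num_workers=4)
#     for batch in loader:
#         input_ids = batch["input_ids"]  # shape (batch_size, seq_length)
#
# Note: the variable-length `images` list in each sample will likely need a
# custom collate_fn when batch_size > 1.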