def __iter__()

in vision/m4/training/dataset.py


    def __iter__(self):
        # Dummy dataset idx used for compatibility with CustomChainDataset
        dummy_dataset_idx = 0

        if self.rank is None or self.world_size is None:
            raise ValueError("rank and world_size must be provided")

        # Get worker indices & details for resuming...
        worker_indices, worker_id = self._get_worker_indices()
        num_worker_indices = len(worker_indices)

        # Set the start index of the loop from the per-worker resume state in `self.worker_idx_tracker`
        map_start_idx, last_key_idx, overflow_batch = self.worker_idx_tracker.get(worker_id, (0, -1, {}))
        self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, map_start_idx]
        self.rng = np.random.default_rng(seed=self.rng_seed)
        for i in range(map_start_idx, num_worker_indices, self.mapper_batch_size):
            # Seed torch from the global seed, the worker id and the map index, then restore the
            # previous RNG state afterwards. This is needed so that torch's random crop is deterministic
            rng_state = torch.get_rng_state()
            torch.manual_seed(int(f"{self.seed}{worker_id}{i}"))
            # Pass a `prefix_seed` to the mapper so its randomness is fully determined by
            # (seed, epoch, rank, worker, index) and does not need to be tracked separately
            curr_mapped_batch = self.mapper(
                self.dataset[worker_indices[i : i + self.mapper_batch_size]],
                prefix_seed=(self.seed, self.epoch, self.rank, worker_id, i),
            )
            torch.set_rng_state(rng_state)
            keys = list(curr_mapped_batch.keys())
            overflow_batch_keys = overflow_batch.keys()

            # If overflow from previous batches is left, add it to the current batch. Specifically,
            # prepend it so that it goes out first and the current batch can become the next overflow batch
            if len(overflow_batch_keys) > 0:
                if sorted(overflow_batch_keys) != sorted(keys):
                    raise ValueError(
                        "Overflow batch keys are not equal to the current batch keys. Make sure the mapper"
                        " always returns a dictionary with the same keys. "
                        f"Overflow: {sorted(overflow_batch_keys)}, Mapping: {sorted(keys)}"
                    )
                else:
                    mapped_batch = {}

                    if "pixel_values" in overflow_batch or "pixel_values" in curr_mapped_batch:
                        total_batch_size = overflow_batch["input_ids"].size(0) + curr_mapped_batch["input_ids"].size(0)
                        max_num_images = max(
                            overflow_batch["pixel_values"].size(1) if "pixel_values" in overflow_batch else 0,
                            curr_mapped_batch["pixel_values"].size(1) if "pixel_values" in curr_mapped_batch else 0,
                        )
                        max_height = max(
                            overflow_batch["pixel_values"].size(3) if "pixel_values" in overflow_batch else 0,
                            curr_mapped_batch["pixel_values"].size(3) if "pixel_values" in curr_mapped_batch else 0,
                        )
                        max_width = max(
                            overflow_batch["pixel_values"].size(4) if "pixel_values" in overflow_batch else 0,
                            curr_mapped_batch["pixel_values"].size(4) if "pixel_values" in curr_mapped_batch else 0,
                        )
                        padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
                        padded_pixel_attention_masks = torch.zeros(
                            total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool
                        )

                        start = 0
                        for batch in [overflow_batch, curr_mapped_batch]:
                            if "pixel_values" not in batch:
                                continue
                            px = batch["pixel_values"]
                            px_attn_mask = batch["pixel_attention_mask"]
                            end = start + px.size(0)
                            padded_image_tensor[start:end, : px.size(1), :, : px.size(3), : px.size(4)] = px
                            padded_pixel_attention_masks[start:end, : px.size(1), : px.size(3), : px.size(4)] = px_attn_mask
                            start += px.size(0)

                        mapped_batch["pixel_values"] = padded_image_tensor.contiguous()
                        mapped_batch["pixel_attention_mask"] = padded_pixel_attention_masks.contiguous()

                    for key in keys:
                        if key in ["pixel_values", "pixel_attention_mask"]:
                            continue
                        mapped_batch[key] = torch.cat([overflow_batch[key], curr_mapped_batch[key]], dim=0)
                    previous_overflow_batch = copy.deepcopy(overflow_batch)
                    overflow_batch = {}
            else:
                previous_overflow_batch = {}
                mapped_batch = curr_mapped_batch

            first_key = keys[0]
            mapped_batch_length = len(mapped_batch[first_key])

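            # Optionally shuffle the packed samples so they are not yielded in mapper order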
            if self.shuffle_after_packing:
                indices = list(range(mapped_batch_length))
                self.rng.shuffle(indices)

                for key in mapped_batch.keys():
                    mapped_batch[key] = mapped_batch[key][indices, ...]

            if mapped_batch_length < self.batch_size:
                # We need more data to reach `self.batch_size`. Carrying the mapped batch over as
                # the overflow batch is enough, as the next iteration will add more data to it
                overflow_batch = mapped_batch
            else:
                # Now, yield batches of size batch_size from the mapped batch
                for key_idx in range(0, mapped_batch_length, self.batch_size):
                    # Set "reproducible" randomness
                    self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i, key_idx]
                    self.rng = np.random.default_rng(seed=self.rng_seed)

                    if i == map_start_idx and key_idx <= last_key_idx:
                        # Handle resume (only on the first loop iteration): skip batches consumed
                        # before the checkpoint, advancing the random state up to `last_key_idx`
                        self.rng.random()
                    else:
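                        # Reuse `overflow_batch` for the candidate slice: if it is short it is kept
                        # as overflow, otherwise it is yielded and cleared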
                        overflow_batch = {key: mapped_batch[key][key_idx : key_idx + self.batch_size] for key in keys}
                        if len(overflow_batch[first_key]) != self.batch_size:
                            # Short last slice: keep it as overflow for the next mapped batch
                            break
                        else:
                            dataset_state = {
                                "worker_idx": worker_id,
                                "map_start_idx": i,
                                "last_key_idx": key_idx,
                                "previous_overflow_batch": previous_overflow_batch,
                            }
                            yield dummy_dataset_idx, self.dataset_name.name.lower(), dataset_state, overflow_batch
                            overflow_batch = {}
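
The per-batch RNG is keyed on a seed *sequence* rather than a single integer, so every
(seed, epoch, rank, worker, map index) prefix yields an independent, reproducible stream. A
minimal standalone sketch of the idea (the helper name `stream_for` is illustrative, not part
of the module):

    import numpy as np

    def stream_for(seed, epoch, rank, worker_id, map_idx):
        # numpy accepts a sequence of ints as a seed; distinct prefixes give
        # independent, reproducible generators
        return np.random.default_rng(seed=[seed, epoch, rank, worker_id, map_idx])

    # Same prefix -> same stream; different worker -> different stream (with overwhelming probability)
    assert stream_for(42, 0, 0, 1, 0).random() == stream_for(42, 0, 0, 1, 0).random()
    assert stream_for(42, 0, 0, 1, 0).random() != stream_for(42, 0, 0, 2, 0).random()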
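
The pixel-value merge above can be hard to follow inline. Here is a self-contained sketch of
the same padding logic, assuming `pixel_values` is shaped (batch, num_images, 3, height, width);
`merge_pixel_values` is a hypothetical helper, not part of the module:

    import torch

    def merge_pixel_values(batches):
        # Zero-padded merge of image tensors that differ in image count and resolution
        total = sum(b["pixel_values"].size(0) for b in batches)
        max_n = max(b["pixel_values"].size(1) for b in batches)
        max_h = max(b["pixel_values"].size(3) for b in batches)
        max_w = max(b["pixel_values"].size(4) for b in batches)
        out = torch.zeros(total, max_n, 3, max_h, max_w)
        start = 0
        for b in batches:
            px = b["pixel_values"]
            out[start : start + px.size(0), : px.size(1), :, : px.size(3), : px.size(4)] = px
            start += px.size(0)
        return out

    merged = merge_pixel_values(
        [{"pixel_values": torch.ones(2, 1, 3, 8, 8)}, {"pixel_values": torch.ones(3, 2, 3, 16, 12)}]
    )
    assert merged.shape == (5, 2, 3, 16, 12)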
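
Stripped of tensors and resume state, the packing discipline reduces to: accumulate mapped
batches, yield fixed-size slices, and carry any short remainder forward as overflow. A
hypothetical minimal version:

    def pack(batches, batch_size):
        # Accumulate items and emit fixed-size slices; a short remainder is carried
        # into the next batch, and whatever is left at the end is never yielded,
        # mirroring the short final slice above
        overflow = []
        for batch in batches:
            overflow.extend(batch)
            while len(overflow) >= batch_size:
                yield overflow[:batch_size]
                overflow = overflow[batch_size:]

    assert list(pack([[1, 2, 3], [4, 5], [6, 7, 8, 9]], batch_size=4)) == [[1, 2, 3, 4], [5, 6, 7, 8]]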