def __iter__()

in vision/m4/training/dataset.py [0:0]


    def __iter__(self):
        """Yield fixed-size packed batches built from ``self.dataset`` for this rank/worker.

        Every raw item pulled with ``next(self.dataset)`` is passed through ``self.mapper``.
        Mapped rows that do not fill a whole batch of ``self.batch_size`` rows are carried over
        ("overflow") and prepended to the next mapped batch. Iteration stops once
        ``self.dataset`` is exhausted; any overflow rows still pending at that point are not
        yielded.

        Side effects:
            Sets ``self.rng_seed`` / ``self.rng`` (NumPy generator used for the optional
            post-packing shuffle). Temporarily reseeds the global torch RNG around each
            fetch + mapper call and restores the saved CPU RNG state afterwards.

        Yields:
            Tuples ``(dataset_idx, dataset_name, dataset_state, batch)`` where ``dataset_idx``
            is always ``0`` (compatibility with ``CustomChainDataset``), ``dataset_name`` is
            ``self.dataset_name.name.lower()``, ``dataset_state`` is a dict with the worker id,
            the current map index and the chunk offset, and ``batch`` maps every mapper key to
            a slice of exactly ``self.batch_size`` rows.

        Raises:
            ValueError: if ``self.rank`` or ``self.world_size`` is ``None``, or if the mapper
                returns keys that differ from those of the pending overflow batch.
        """
        # Dummy dataset idx used for compatibility with CustomChainDataset
        dummy_dataset_idx = 0

        if self.rank is None or self.world_size is None:
            raise ValueError("rank and world_size must be provided")

        # Relic from previous implementation - but needed for rng seed (worker count unused here)
        worker_id, _ = self._get_worker_id_and_worker_total_num()

        # Relic from previous implementation - but needed for rng seed (last key index unused here)
        map_start_idx, _, overflow_batch = self.worker_idx_tracker.get(worker_id, (0, -1, {}))

        # Index of the next raw item; it feeds every seed computed below.
        i = map_start_idx

        # Initialize rng_seed
        self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i]
        self.rng = np.random.default_rng(seed=self.rng_seed)
        while True:
            # Seed torch per (seed, worker, index) so torch randomness inside the mapper (e.g. random
            # crops) is deterministic, then put the caller's CPU RNG state back. The ``finally`` makes
            # the restore happen on exhaustion (``break``) and on mapper errors as well.
            # torch.manual_seed() coerces its argument with int(), hence the digit string.
            # NOTE(review): torch.manual_seed also reseeds CUDA generators, which
            # torch.get_rng_state() does not capture, so only the CPU state is restored.
            rng_state = torch.get_rng_state()
            torch.manual_seed(f"{self.seed}{worker_id}{i}")
            try:
                try:
                    next_batch = next(self.dataset)
                except StopIteration:
                    logger.info(
                        f"{self.dataset_name.name.lower()} has finished one epoch and is moving on to the next one."
                        f" (epoch={self.epoch} - rank={self.rank} - worker_id={worker_id})"
                    )
                    break
                i += 1
                curr_mapped_batch = self.mapper(
                    next_batch,
                    prefix_seed=(self.seed, self.epoch, self.rank, worker_id, i),
                )
            finally:
                torch.set_rng_state(rng_state)
            keys = list(curr_mapped_batch.keys())

            # Prepend any overflow left from previous batches so it goes out first and the tail
            # of the current batch possibly becomes the next overflow.
            if overflow_batch:
                if sorted(overflow_batch.keys()) != sorted(keys):
                    raise ValueError(
                        "Overflow batch keys not equal to current keys. Make sure mapper is always returning"
                        "  dictionary with the same keys. "
                        f"Overflow: {sorted(overflow_batch.keys())}, Mapping: {sorted(keys)}"
                    )
                mapped_batch = {}

                # Key sets are equal at this point, so either both batches carry images or neither does.
                if "pixel_values" in keys:
                    # Layout assumed by this code (and by the original allocations):
                    # pixel_values (batch, num_images, channels, height, width),
                    # pixel_attention_mask (batch, num_images, height, width).
                    # Every non-batch dim may differ between the two batches, so zero-pad each to its max.
                    sources = (overflow_batch, curr_mapped_batch)
                    pixel_values = [batch["pixel_values"] for batch in sources]
                    total_batch_size = overflow_batch["input_ids"].size(0) + curr_mapped_batch["input_ids"].size(0)
                    max_num_images, max_channels, max_height, max_width = (
                        max(px.size(dim) for px in pixel_values) for dim in (1, 2, 3, 4)
                    )
                    # Keep the mapper's dtype so merged batches match un-merged ones.
                    padded_pixel_values = torch.zeros(
                        total_batch_size,
                        max_num_images,
                        max_channels,
                        max_height,
                        max_width,
                        dtype=torch.promote_types(pixel_values[0].dtype, pixel_values[1].dtype),
                    )
                    padded_pixel_attention_mask = torch.zeros(
                        total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool
                    )

                    start = 0
                    for batch, px in zip(sources, pixel_values):
                        num_rows, num_images, num_channels, height, width = px.shape
                        end = start + num_rows
                        # Slice every padded dim explicitly: a bare ``:`` on the image dim would
                        # broadcast a single-image batch into all image slots (or fail outright).
                        padded_pixel_values[start:end, :num_images, :num_channels, :height, :width] = px
                        padded_pixel_attention_mask[start:end, :num_images, :height, :width] = batch[
                            "pixel_attention_mask"
                        ]
                        start = end

                    mapped_batch["pixel_values"] = padded_pixel_values
                    mapped_batch["pixel_attention_mask"] = padded_pixel_attention_mask

                # Every other key is concatenated row-wise, overflow first.
                for key in keys:
                    if key not in ("pixel_values", "pixel_attention_mask"):
                        mapped_batch[key] = torch.cat([overflow_batch[key], curr_mapped_batch[key]], dim=0)
                overflow_batch = {}
            else:
                mapped_batch = curr_mapped_batch

            if not keys:
                # The mapper produced nothing and no overflow is pending: fetch the next item.
                continue

            first_key = keys[0]
            mapped_batch_length = len(mapped_batch[first_key])

            if self.shuffle_after_packing:
                permutation = list(range(mapped_batch_length))
                self.rng.shuffle(permutation)
                for key in mapped_batch:
                    mapped_batch[key] = mapped_batch[key][permutation, ...]

            if mapped_batch_length < self.batch_size:
                # Not enough rows for a full batch yet: carry everything over; the next
                # iteration will prepend it to fresh data.
                overflow_batch = mapped_batch
                continue

            # Yield batches of size batch_size; a trailing partial slice becomes the next overflow.
            for key_idx in range(0, mapped_batch_length, self.batch_size):
                # Set "reproducible" randomness
                self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i, key_idx]
                self.rng = np.random.default_rng(seed=self.rng_seed)

                chunk = {key: mapped_batch[key][key_idx : key_idx + self.batch_size] for key in keys}
                if len(chunk[first_key]) != self.batch_size:
                    # Last (partial) batch
                    overflow_batch = chunk
                    break
                dataset_state = {
                    "worker_idx": worker_id,
                    "map_start_idx": i,
                    "last_key_idx": key_idx,
                    "previous_overflow_batch": {},
                }
                yield dummy_dataset_idx, self.dataset_name.name.lower(), dataset_state, chunk