def get_iterators()

in egg/zoo/objects_game/features.py [0:0]


    def get_iterators(self):
        """Build the train/validation/test DataLoaders for the objects game.

        The splits come from ``self.load_data(self.load_data_path)`` when
        ``load_data_path`` is set. Otherwise ``self.generate_tuples`` builds them
        from every vector in the Cartesian product of ``range(1, d + 1)`` for each
        ``d`` in ``self.perceptual_dimensions``.

        All three loaders use ``self.batch_size``, ``self.collate`` and
        ``drop_last=True``. Only the training loader is shuffled, controlled by
        ``self.shuffle_train_data``. When ``self.dump_data_folder`` is set, the
        splits are also saved with ``np.savez_compressed`` inside that folder.

        Returns:
            A ``(train_it, validation_it, test_it)`` tuple of DataLoaders.

        Raises:
            AssertionError: if a split size is not positive or the requested
                splits exceed the number of possible tuples (generation path
                only), or if the batch size is larger than any split size.
        """
        if self.load_data_path:
            train, valid, test = self.load_data(self.load_data_path)
        else:  # no data file given: generate the tuples from scratch
            # Validate the requested sizes before materializing every vector.
            assert (
                self.train_samples > 0
                and self.validation_samples > 0
                and self.test_samples > 0
            ), "Train size, validation size and test size must all be greater than 0"

            # Number of distinct vectors in the world; possible_tuples is
            # presumably C(world_dim, n_distractors + 1) (see compute_binomial).
            world_dim = reduce(lambda x, y: x * y, self.perceptual_dimensions)
            possible_tuples = compute_binomial(world_dim, self.n_distractors + 1)
            requested = self.train_samples + self.validation_samples + self.test_samples
            assert possible_tuples > requested, (
                "Not enough data for requested split sizes "
                f"({requested} requested, {possible_tuples} possible tuples). "
                "Reduce split samples or increase perceptual_dimensions"
            )

            # Feature i of each vector ranges over 1..perceptual_dimensions[i].
            list_of_dim = [range(1, elem + 1) for elem in self.perceptual_dimensions]
            all_vectors = list(itertools.product(*list_of_dim))
            train, valid, test = self.generate_tuples(data=all_vectors)

        # With drop_last=True, a split smaller than one batch would yield no batches.
        assert (
            self.train_samples >= self.batch_size
            and self.validation_samples >= self.batch_size
            and self.test_samples >= self.batch_size
        ), "Batch size cannot be larger than any split size"

        def make_loader(dataset, shuffle=False):
            # Shared loader settings for all three splits; only shuffling differs.
            return data.DataLoader(
                dataset,
                batch_size=self.batch_size,
                collate_fn=self.collate,
                drop_last=True,
                shuffle=shuffle,
            )

        train_it = make_loader(TupleDataset(*train), shuffle=self.shuffle_train_data)
        validation_it = make_loader(TupleDataset(*valid))
        test_it = make_loader(TupleDataset(*test))

        if self.dump_data_folder:
            # parents=True so a nested, not-yet-existing dump folder also works.
            self.dump_data_folder.mkdir(parents=True, exist_ok=True)
            path = (
                self.dump_data_folder
                / f"{self.perceptual_dimensions}_{self.n_distractors}_distractors"
            )
            # Element 0 of each split is stored under the split name and
            # element 1 under "<split>_labels".
            np.savez_compressed(
                path,
                train=train[0],
                train_labels=train[1],
                valid=valid[0],
                valid_labels=valid[1],
                test=test[0],
                test_labels=test[1],
                n_distractors=self.n_distractors,
            )

        return train_it, validation_it, test_it