in egg/zoo/objects_game/features.py [0:0]
def get_iterators(self):
    """Build the train, validation and test data loaders.

    The three splits are read through ``self.load_data`` when
    ``self.load_data_path`` is set. Otherwise they are produced by
    ``self.generate_tuples`` from the cartesian product of the value ranges
    ``1..d`` for every ``d`` in ``self.perceptual_dimensions``.

    When ``self.dump_data_folder`` is set, the splits and the number of
    distractors are additionally written to a compressed NumPy archive
    inside that folder.

    Returns:
        A ``(train_it, validation_it, test_it)`` tuple of ``DataLoader``
        objects. All three use ``self.batch_size``, ``self.collate`` and
        ``drop_last=True``; only the training loader honours
        ``self.shuffle_train_data``.

    Raises:
        AssertionError: if a requested split size is not positive, if the
            number of possible tuples does not exceed the total number of
            requested samples, or if the batch size is larger than any of
            the split sizes.
    """
    if self.load_data_path:
        train, valid, test = self.load_data(self.load_data_path)
    else:
        # No dataset on disk was given, so the tuples have to be generated.
        # Validate the request first: the checks are cheap, while the
        # cartesian product below can be large.
        assert (
            self.train_samples > 0
            and self.validation_samples > 0
            and self.test_samples > 0
        ), "Train size, validation size and test size must all be greater than 0"

        # Number of distinct vectors in the world (product of all dimensions).
        world_dim = reduce(lambda x, y: x * y, self.perceptual_dimensions)
        # presumably the number of ways to pick a target plus its distractors
        # out of the world; verify against compute_binomial.
        possible_tuples = compute_binomial(world_dim, self.n_distractors + 1)
        assert (
            possible_tuples
            > self.train_samples + self.validation_samples + self.test_samples
        ), "Not enough data for requested split sizes. Reduced split samples or increase perceptual_dimensions"

        list_of_dim = [range(1, elem + 1) for elem in self.perceptual_dimensions]
        all_vectors = list(itertools.product(*list_of_dim))
        train, valid, test = self.generate_tuples(data=all_vectors)

    # Every loader drops its last incomplete batch, so a split smaller than
    # one batch would yield no data at all.
    # NOTE(review): in the load_data branch this reads self.*_samples after
    # loading; presumably load_data updates them -- confirm.
    assert (
        self.train_samples >= self.batch_size
        and self.validation_samples >= self.batch_size
        and self.test_samples >= self.batch_size
    ), "Batch size cannot be larger than any split size"

    train_dataset = TupleDataset(*train)
    valid_dataset = TupleDataset(*valid)
    test_dataset = TupleDataset(*test)

    train_it = data.DataLoader(
        train_dataset,
        batch_size=self.batch_size,
        collate_fn=self.collate,
        drop_last=True,
        shuffle=self.shuffle_train_data,
    )
    validation_it = data.DataLoader(
        valid_dataset,
        batch_size=self.batch_size,
        collate_fn=self.collate,
        drop_last=True,
    )
    test_it = data.DataLoader(
        test_dataset,
        batch_size=self.batch_size,
        collate_fn=self.collate,
        drop_last=True,
    )

    if self.dump_data_folder:
        # parents=True so that a nested, not-yet-existing dump location does
        # not abort the run after the data has already been built.
        self.dump_data_folder.mkdir(parents=True, exist_ok=True)
        path = (
            self.dump_data_folder
            / f"{self.perceptual_dimensions}_{self.n_distractors}_distractors"
        )
        # Each split is stored as two arrays: its first element under the
        # split name and its second under "<split>_labels".
        np.savez_compressed(
            path,
            train=train[0],
            train_labels=train[1],
            valid=valid[0],
            valid_labels=valid[1],
            test=test[0],
            test_labels=test[1],
            n_distractors=self.n_distractors,
        )

    return train_it, validation_it, test_it