in per_class_augmentation/data.py [0:0]
def _create_dataloader(self, stage: str, augmentations: transform_lib.Compose):
    """Build the DataLoader for one data split.

    Images for the split are read from ``os.path.join(self.data_dir, stage)``.

    Args:
        stage: Name of the split sub-directory. ``"train"`` selects a
            ``TopAugmentationsDataset`` configured from this object's
            attributes and turns shuffling on. Any other value selects a
            ``torchvision.datasets.ImageFolder`` and turns shuffling off.
        augmentations: Transform handed to the ``ImageFolder`` for
            non-train stages. It is not passed to the train dataset.

    Returns:
        A ``DataLoader`` over the chosen dataset, using ``self.batch_size``,
        ``self.num_workers`` and ``pin_memory=True``.
    """
    split_dir = os.path.join(self.data_dir, stage)
    is_train = stage == "train"

    if is_train:
        # NOTE(review): `augmentations` is not forwarded here; the train
        # dataset is configured only through the options below. Confirm
        # this is intended.
        top_aug_options = dict(
            transform_dir=self.top_transforms_dir,
            num_transforms=self.num_transforms,
            similarity_type=self.similarity_type,
            plus_standard_aug=self.plus_standard_aug,
            standard_aug_before=self.standard_aug_before,
            top_per_class=self.top_per_class,
            top_transform_ranking=self.top_transform_ranking,
            transform_prob=self.transform_prob,
            min_prop_boosted_filter=self.min_prop_boosted_filter,
            min_perc_change_per_class_filter=self.min_perc_change_per_class_filter,
        )
        split_dataset = TopAugmentationsDataset(split_dir, **top_aug_options)
    else:
        split_dataset = torchvision.datasets.ImageFolder(split_dir, augmentations)

    # Only the training split is shuffled; evaluation order stays fixed.
    return DataLoader(
        split_dataset,
        batch_size=self.batch_size,
        pin_memory=True,
        num_workers=self.num_workers,
        shuffle=is_train,
    )