in quant/data/data_loaders.py [0:0]
def get_train_loader(self) -> DataLoader:
    """Build the DataLoader that feeds the training split.

    Images are read with ``torchvision.datasets.ImageFolder`` from
    ``Path(self.dataset_path) / self.train_split``. Each sample is passed
    through a random-augmentation pipeline:

    1. random resized crop to 224;
    2. random horizontal flip;
    3. colour jitter (0.4 for brightness, contrast and saturation);
    4. tensor conversion;
    5. ``self.normalize``.

    Returns:
        A ``torch.utils.data.DataLoader`` with the following settings:
        shuffling enabled, ``self.train_batch_size`` samples per batch,
        ``num_workers=self.workers`` and ``pin_memory=True``.
    """
    root = Path(self.dataset_path) / self.train_split

    augment = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        # Positional args are (brightness, contrast, saturation); hue is left at its default.
        transforms.ColorJitter(0.4, 0.4, 0.4),
        transforms.ToTensor(),
        self.normalize,
    ])

    dataset = datasets.ImageFolder(root, transform=augment)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=self.train_batch_size,
        shuffle=True,
        num_workers=self.workers,
        pin_memory=True,
    )