in quant/data/data_loaders.py [0:0]
def get_train_loader(self) -> DataLoader:
    """Build the data loader used to iterate over the CIFAR-100 training split.

    Each sample goes through, in order: a random 32-pixel crop with 4 pixels
    of padding, a random horizontal flip, conversion to a tensor, and
    normalization with ``self.mean_val`` / ``self.std_val``.

    The dataset is read from ``self.dataset_path`` and is downloaded first
    when ``self.download`` is set.

    Returns:
        A ``DataLoader`` over the training split that shuffles the data,
        yields batches of ``self.train_batch_size`` samples, uses
        ``self.workers`` worker processes, and has ``pin_memory`` enabled.
    """
    # Augmentation steps come first; tensor conversion and normalization last.
    augmentation_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(self.mean_val, self.std_val),
    ]
    augment = transforms.Compose(augmentation_steps)

    cifar100_train = datasets.CIFAR100(
        root=self.dataset_path,
        train=True,
        download=self.download,
        transform=augment,
    )

    return torch.utils.data.DataLoader(
        cifar100_train,
        batch_size=self.train_batch_size,
        shuffle=True,
        num_workers=self.workers,
        pin_memory=True,
    )