in data_utils/functions_bis.py [0:0]
def return_loader_and_sampler(args, traindir, valdir, return_train = True):
    """Build the training and validation data loaders.

    Args:
        args: Configuration object. The attributes read here are
            ``augment_train``, ``augment_valid``, ``distributed``,
            ``batch_size``, ``world_size`` (distributed mode only) and
            ``workers``. It is also forwarded to
            ``return_augmentations_types``.
        traindir: Location handed to ``MyImageFolder`` for the training set.
        valdir: Location handed to ``MyImageFolder`` for the validation set.
        return_train: When false, no training dataset or loader is created
            and ``None`` is returned in place of the training loader.

    Returns:
        A ``(train_loader, val_loader, train_sampler)`` tuple.
        ``train_sampler`` is ``None`` unless ``args.distributed`` is truthy.
    """
    transforms_by_name = return_augmentations_types(args)

    # An empty list stands in for the dataset when training data is not
    # requested; in distributed mode the sampler below is still built on it.
    train_dataset = (
        MyImageFolder(traindir, transforms_by_name[args.augment_train])
        if return_train
        else []
    )

    if not args.distributed:
        train_sampler = None
        batch_size = args.batch_size
    else:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        # args.batch_size is split across the processes, so each one gets
        # its share (truncated to an int).
        batch_size = int(args.batch_size / args.world_size)
        print(f"batch size per GPU is {batch_size}")

    train_loader = None
    if return_train:
        # Shuffling is only requested when no sampler is supplied.
        shuffle_train = train_sampler is None
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=shuffle_train,
            num_workers=args.workers,
            sampler=train_sampler,
            pin_memory=True,
        )
    # Emitted whether or not a training loader was actually built.
    print("Train loader initiated")

    # The validation loader reuses the same batch size, is never shuffled
    # and is not given a sampler.
    val_dataset = MyImageFolder(valdir, transforms_by_name[args.augment_valid])
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    print("Val loader initiated")

    return train_loader, val_loader, train_sampler