def return_loader_and_sampler()

in data_utils/functions_bis.py


def return_loader_and_sampler(args, traindir, valdir, return_train=True):

    augmentations = return_augmentations_types(args)

    if return_train:
        train_dataset = MyImageFolder(
            traindir, augmentations[args.augment_train])
    else:
        # No training set requested; only the validation loader is built.
        train_dataset = None

    if args.distributed:
        # The DistributedSampler is only needed when a training set was built.
        train_sampler = (
            torch.utils.data.distributed.DistributedSampler(train_dataset)
            if return_train else None
        )
        # Per-GPU batch size for DistributedDataParallel,
        # e.g. a global batch of 256 on 8 GPUs gives 32 per GPU.
        batch_size = int(args.batch_size / args.world_size)
        print(f"batch size per GPU is {batch_size}")
    else:
        train_sampler = None
        batch_size = args.batch_size

    if return_train:
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=(train_sampler is None),  # sampler and shuffle are mutually exclusive
            num_workers=args.workers,
            sampler=train_sampler,
            pin_memory=True,
        )
        print("Train loader initiated")
    else:
        train_loader = None

    # Validation loader: no sampler, never shuffled, same (per-GPU) batch size.
    val_loader = torch.utils.data.DataLoader(
        MyImageFolder(
            valdir,
            augmentations[args.augment_valid]),
        batch_size=batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
    )
    print("Val loader initiated")
    return train_loader, val_loader, train_sampler
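
A minimal usage sketch, assuming `args` is an `argparse.Namespace` exposing the fields the function reads (`augment_train`, `augment_valid`, `distributed`, `world_size`, `batch_size`, `workers`). The augmentation keys, directory paths, and epoch count below are illustrative placeholders, and the loop assumes `MyImageFolder` yields `(image, target)` pairs in the usual `ImageFolder` fashion.

# Hypothetical caller; the concrete field values are placeholders, not
# the repository's actual defaults.
from argparse import Namespace

from data_utils.functions_bis import return_loader_and_sampler

args = Namespace(
    augment_train="train_default",   # assumed key into return_augmentations_types(args)
    augment_valid="valid_default",   # assumed key as well
    distributed=False,               # set True when running under DistributedDataParallel
    world_size=1,
    batch_size=256,
    workers=8,
)

train_loader, val_loader, train_sampler = return_loader_and_sampler(
    args, traindir="data/train", valdir="data/val", return_train=True
)

for epoch in range(10):
    # With a DistributedSampler, set_epoch() must be called every epoch so
    # each rank reshuffles its shard differently across epochs.
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)
    for images, target in train_loader:
        ...  # forward / backward pass

When `return_train=False`, only the validation loader is built and both `train_loader` and `train_sampler` come back as `None`.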