# image_loader() — from ppuda/vision/loader.py

def image_loader(dataset='imagenet', data_dir='./data/', test=True, fine_tune=False,
                 batch_size=64, test_batch_size=64, num_workers=0,
                 cutout=False, cutout_length=16, noise=False,
                 seed=1111, load_train_anyway=False, n_shots=None):
    """Build training and evaluation DataLoaders for ImageNet or a torchvision dataset.

    :param dataset: image dataset: imagenet, cifar10, cifar100, etc.
    :param data_dir: location of the dataset
    :param test: True to load the test data for evaluation, False to load the validation data.
    :param fine_tune: True when fine-tuning an ImageNet model on CIFAR-10
    :param batch_size: training batch size of images
    :param test_batch_size: evaluation batch size of images
    :param num_workers: number of threads to load/preprocess images
    :param cutout: use Cutout for data augmentation from
                   "Terrance DeVries, Graham W. Taylor. Improved Regularization of Convolutional Neural Networks with Cutout. 2017."
    :param cutout_length: Cutout hyperparameter
    :param noise: evaluate on the images with added Gaussian noise
    :param seed: random seed to shuffle validation images on ImageNet
    :param load_train_anyway: load training images even when evaluating on test data (test=True)
    :param n_shots: the number of training images per class (only for CIFAR-10 and other torchvision datasets and when test=True)
    :return: (train_loader or None, valid_loader, n_classes) — torch DataLoaders and the number of classes.
    """
    train_data = None

    if dataset.lower() == 'imagenet':
        train_transform, valid_transform = transforms_imagenet(noise=noise, cifar_style=False)
        imagenet_dir = os.path.join(data_dir, 'imagenet')

        if not test or load_train_anyway:
            # has_validation=not test: when evaluating on the val split (test=False),
            # the training set excludes the held-out validation images.
            train_data = ImageNetDataset(imagenet_dir, 'train', transform=train_transform, has_validation=not test)

        valid_data = ImageNetDataset(imagenet_dir, 'val', transform=valid_transform, has_validation=not test)

        # Shuffle evaluation images so that models with batch norm kept in training
        # mode (no running statistics) see representative batches.
        shuffle_val = True
        n_classes = len(valid_data.classes)
        generator = torch.Generator()
        generator.manual_seed(seed)  # to reproduce evaluation with shuffle=True on ImageNet

    else:
        dataset = dataset.upper()  # torchvision class names are upper case (CIFAR10, CIFAR100, ...)
        train_transform, valid_transform = transforms_cifar(cutout=cutout, cutout_length=cutout_length, noise=noise, sz=224 if fine_tune else 32)

        # NOTE(security): eval resolves `dataset` to a torchvision dataset class
        # imported at module level. Resolve it once here instead of formatting and
        # eval-ing full call strings four times; never pass untrusted strings as `dataset`.
        dataset_cls = eval(dataset)

        if test:
            valid_data = dataset_cls(data_dir, train=False, download=True, transform=valid_transform)
            if load_train_anyway:
                train_data = dataset_cls(data_dir, train=True, download=True, transform=train_transform)
                if n_shots is not None:
                    train_data = to_few_shot(train_data, n_shots=n_shots)
        else:
            if n_shots is not None:
                print('few shot regime is only supported for evaluation on the test data')
            # Hold out 10% (e.g. 5000 images in case of CIFAR-10) of training data as the validation set.
            # Two dataset objects are created so train/val get different transforms.
            train_data = dataset_cls(data_dir, train=True, download=True, transform=train_transform)
            valid_data = dataset_cls(data_dir, train=True, download=True, transform=valid_transform)
            n_all = len(train_data.targets)
            n_val = n_all // 10
            idx_train, idx_val = torch.split(torch.arange(n_all), [n_all - n_val, n_val])

            train_data.data = train_data.data[idx_train]
            train_data.targets = [train_data.targets[i] for i in idx_train]

            valid_data.data = valid_data.data[idx_val]
            valid_data.targets = [valid_data.targets[i] for i in idx_val]

        if train_data is not None:
            # Mean pixel value acts as a cheap checksum to detect data mismatches across runs.
            train_data.checksum = train_data.data.mean()
            train_data.num_examples = len(train_data.targets)

        shuffle_val = False
        n_classes = len(torch.unique(torch.tensor(valid_data.targets)))
        generator = None

        valid_data.checksum = valid_data.data.mean()
        valid_data.num_examples = len(valid_data.targets)

    # NOTE(review): in the ImageNet branch, checksum/num_examples are assumed to be
    # provided by ImageNetDataset itself — confirm against ppuda/vision/imagenet.py.
    print('loaded {}: {} classes, {} train samples (checksum={}), '
          '{} {} samples (checksum={:.3f})'.format(dataset,
                                                   n_classes,
                                                   train_data.num_examples if train_data else 'none',
                                                   ('%.3f' % train_data.checksum) if train_data else 'none',
                                                   valid_data.num_examples,
                                                   'test' if test else 'val',
                                                   valid_data.checksum))

    if train_data is None:
        train_loader = None
    else:
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                                  pin_memory=True, num_workers=num_workers)

    valid_loader = DataLoader(valid_data, batch_size=test_batch_size, shuffle=shuffle_val,
                              pin_memory=True, num_workers=num_workers, generator=generator)

    return train_loader, valid_loader, n_classes