in datasets.py [0:0]
def build_dataset(is_train, args):
    """Build the dataset selected by ``args.data_set``.

    Args:
        is_train: Truthy to build the training split, falsy for the
            evaluation split. Forwarded to ``build_transform`` and to the
            dataset constructors (for ``IMNET`` it selects the ``train`` or
            ``val`` sub-directory of ``args.data_path``).
        args: Namespace-like object. ``data_set`` is always read; depending
            on the selected dataset ``data_path``, ``sampling_ratio``,
            ``nb_classes`` and ``inat_category`` are read as well.

    Returns:
        A ``(dataset, nb_classes)`` tuple.

    Raises:
        ValueError: If ``args.data_set`` is not one of the supported names.

    Note:
        For ``CIFAR10``, ``CIFAR100`` and ``INAT19`` this function overwrites
        ``args.data_path`` with a hard-coded location, so the caller's
        ``args`` object is mutated.
    """
    supported = ('CIFAR10', 'CIFAR100', 'IMNET', 'INAT', 'INAT19')
    # Fail fast with a clear message instead of hitting an UnboundLocalError
    # on the final return when no branch below matches.
    if args.data_set not in supported:
        raise ValueError(
            "Unsupported data_set {!r}; expected one of: {}".format(
                args.data_set, ', '.join(supported)))

    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR10':
        # Hard-coded location: any user-supplied data_path is overridden.
        args.data_path = "/datasets01/cifar-pytorch/11222017/"
        dataset = datasets.CIFAR10(args.data_path, train=is_train, transform=transform)
        nb_classes = 10
    elif args.data_set == 'CIFAR100':
        # Hard-coded location: any user-supplied data_path is overridden.
        args.data_path = "/datasets01/cifar100/022818/data/"
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        # Sub-sampling is only applied to the training split; evaluation
        # always uses a ratio of 1.
        dataset = ImageNetDataset(
            root, transform=transform,
            sampling_ratio=(args.sampling_ratio if is_train else 1.),
            nb_classes=args.nb_classes)
        # Fall back to 1000 classes when no explicit class count is given.
        nb_classes = args.nb_classes if args.nb_classes is not None else 1000
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    else:  # 'INAT19' -- guaranteed by the membership check above.
        # Hard-coded location: any user-supplied data_path is overridden.
        args.data_path = "/datasets01/inaturalist/090619/"
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes