datasets/imagenetv2.py (31 lines of code) (raw):

import os from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase from dassl.utils import listdir_nohidden from .imagenet import ImageNet @DATASET_REGISTRY.register() class ImageNetV2(DatasetBase): """ImageNetV2. This dataset is used for testing only. """ dataset_dir = "imagenetv2" dataset_name = 'imagenetv2' def __init__(self, cfg): root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = os.path.join(root, self.dataset_dir) image_dir = "imagenetv2-matched-frequency-format-val" self.image_dir = os.path.join(self.dataset_dir, image_dir) text_file = os.path.join(self.dataset_dir, "classnames.txt") classnames = ImageNet.read_classnames(text_file) data = self.read_data(classnames) super().__init__(train_x=data, test=data) def read_data(self, classnames): image_dir = self.image_dir folders = list(classnames.keys()) items = [] for label in range(1000): class_dir = os.path.join(image_dir, str(label)) imnames = listdir_nohidden(class_dir) folder = folders[label] classname = classnames[folder] for imname in imnames: impath = os.path.join(class_dir, imname) item = Datum(impath=impath, label=label, classname=classname) items.append(item) return items