Dassl.pytorch/dassl/data/datasets/dg/digit_single.py (64 lines of code) (raw):

import os.path as osp

from dassl.utils import listdir_nohidden

from ..build import DATASET_REGISTRY
from ..base_dataset import Datum, DatasetBase

# Folder names for train and test sets
MNIST = {"train": "train_images", "test": "test_images"}
MNIST_M = {"train": "train_images", "test": "test_images"}
SVHN = {"train": "train_images", "test": "test_images"}
SYN = {"train": "train_images", "test": "test_images"}
USPS = {"train": "train_images", "test": "test_images"}

# Cap on the number of training images kept for the capped domains
# (MNIST, MNIST-M, SVHN, SYN); USPS is loaded without a cap.
_N_TRAIN_MAX = 10000


def read_image_list(im_dir, n_max=None, n_repeat=None):
    """Build a list of ``(impath, label)`` pairs from an image folder.

    The label is parsed from the file name: the extension is dropped, the
    remainder is split on ``"_"`` and the second field is converted to int.

    Args:
        im_dir: directory whose (non-hidden) entries are listed.
        n_max: if not None, keep only the first ``n_max`` items.
        n_repeat: if not None, repeat the (possibly truncated) list
            ``n_repeat`` times.

    Returns:
        list of ``(impath, label)`` tuples.

    Raises:
        IndexError: if a file name contains no ``"_"`` separator.
        ValueError: if the label field is not an integer.
    """
    items = []

    # NOTE(review): the subset kept by ``n_max`` is whatever
    # ``listdir_nohidden`` returns first -- confirm that its order is
    # deterministic across machines if exact reproducibility matters.
    for imname in listdir_nohidden(im_dir):
        imname_noext = osp.splitext(imname)[0]
        label = int(imname_noext.split("_")[1])
        impath = osp.join(im_dir, imname)
        items.append((impath, label))

    if n_max is not None:
        # Note that the sampling process is NOT random,
        # which follows that in Volpi et al. NIPS'18.
        items = items[:n_max]

    if n_repeat is not None:
        items *= n_repeat

    return items


def _load_split(dataset_dir, folders, split, n_max_train=None):
    """Read one split of one domain.

    Args:
        dataset_dir: root directory of the domain.
        folders: mapping from split name to sub-folder name.
        split: key into ``folders`` (a missing key raises KeyError).
        n_max_train: cap applied only when ``split == "train"``;
            other splits are never truncated.
    """
    data_dir = osp.join(dataset_dir, folders[split])
    n_max = n_max_train if split == "train" else None
    return read_image_list(data_dir, n_max=n_max)


def load_mnist(dataset_dir, split="train"):
    """Load MNIST items; the train split is capped at 10,000 images."""
    return _load_split(dataset_dir, MNIST, split, n_max_train=_N_TRAIN_MAX)


def load_mnist_m(dataset_dir, split="train"):
    """Load MNIST-M items; the train split is capped at 10,000 images."""
    return _load_split(dataset_dir, MNIST_M, split, n_max_train=_N_TRAIN_MAX)


def load_svhn(dataset_dir, split="train"):
    """Load SVHN items; the train split is capped at 10,000 images."""
    return _load_split(dataset_dir, SVHN, split, n_max_train=_N_TRAIN_MAX)


def load_syn(dataset_dir, split="train"):
    """Load SYN items; the train split is capped at 10,000 images."""
    return _load_split(dataset_dir, SYN, split, n_max_train=_N_TRAIN_MAX)


def load_usps(dataset_dir, split="train"):
    """Load USPS items; no cap is applied to any split."""
    return _load_split(dataset_dir, USPS, split)


# Explicit domain-name -> loader mapping. This replaces building the
# function name as a string and passing it to ``eval``, so a config value
# can never be evaluated as code.
_DOMAIN_LOADERS = {
    "mnist": load_mnist,
    "mnist_m": load_mnist_m,
    "svhn": load_svhn,
    "syn": load_syn,
    "usps": load_usps,
}


@DATASET_REGISTRY.register()
class DigitSingle(DatasetBase):
    """Digit recognition datasets for single-source domain generalization.

    There are five digit datasets:
        - MNIST: hand-written digits.
        - MNIST-M: variant of MNIST with blended background.
        - SVHN: street view house number.
        - SYN: synthetic digits.
        - USPS: hand-written digits, slightly different from MNIST.

    Protocol:
        Volpi et al. train a model using 10,000 images from MNIST and
        evaluate the model on the test split of the other four datasets.
        However, the code does not restrict you to only use MNIST as the
        source dataset. Instead, you can use any dataset as the source.
        But note that only 10,000 images will be sampled from the source
        dataset for training (the USPS loader applies no such cap).

    Reference:
        - Lecun et al. Gradient-based learning applied to document
        recognition. IEEE 1998.
        - Ganin et al. Domain-adversarial training of neural networks.
        JMLR 2016.
        - Netzer et al. Reading digits in natural images with unsupervised
        feature learning. NIPS-W 2011.
        - Volpi et al. Generalizing to Unseen Domains via Adversarial Data
        Augmentation. NIPS 2018.
    """

    # Reuse the digit-5 folder instead of creating a new folder
    dataset_dir = "digit5"
    domains = ["mnist", "mnist_m", "svhn", "syn", "usps"]

    def __init__(self, cfg):
        root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT))
        self.dataset_dir = osp.join(root, self.dataset_dir)

        self.check_input_domains(
            cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS
        )

        # Validation data is the test split of the source domains; the
        # test set is the test split of the target domains.
        train = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train")
        val = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="test")
        test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test")

        super().__init__(train_x=train, val=val, test=test)

    def _read_data(self, input_domains, split="train"):
        """Collect ``Datum`` items for the given domains and split.

        Args:
            input_domains: sequence of domain names; the position of a
                name in this sequence becomes the ``domain`` label of
                its items.
            split: split name forwarded to the per-domain loader.

        Returns:
            list of ``Datum`` objects, in domain order.

        Raises:
            ValueError: if a domain name has no registered loader.
        """
        items = []

        for domain, dname in enumerate(input_domains):
            try:
                loader = _DOMAIN_LOADERS[dname]
            except KeyError:
                raise ValueError(
                    "Unknown domain '{}', expected one of {}".format(
                        dname, sorted(_DOMAIN_LOADERS)
                    )
                ) from None

            domain_dir = osp.join(self.dataset_dir, dname)
            items_d = loader(domain_dir, split=split)

            for impath, label in items_d:
                item = Datum(impath=impath, label=label, domain=domain)
                items.append(item)

        return items