Dassl.pytorch/dassl/data/datasets/da/digit5.py (71 lines of code) (raw):

import random import os.path as osp from dassl.utils import listdir_nohidden from ..build import DATASET_REGISTRY from ..base_dataset import Datum, DatasetBase # Folder names for train and test sets MNIST = {"train": "train_images", "test": "test_images"} MNIST_M = {"train": "train_images", "test": "test_images"} SVHN = {"train": "train_images", "test": "test_images"} SYN = {"train": "train_images", "test": "test_images"} USPS = {"train": "train_images", "test": "test_images"} def read_image_list(im_dir, n_max=None, n_repeat=None): items = [] for imname in listdir_nohidden(im_dir): imname_noext = osp.splitext(imname)[0] label = int(imname_noext.split("_")[1]) impath = osp.join(im_dir, imname) items.append((impath, label)) if n_max is not None: items = random.sample(items, n_max) if n_repeat is not None: items *= n_repeat return items def load_mnist(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, MNIST[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_mnist_m(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, MNIST_M[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_svhn(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, SVHN[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_syn(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, SYN[split]) n_max = 25000 if split == "train" else 9000 return read_image_list(data_dir, n_max=n_max) def load_usps(dataset_dir, split="train"): data_dir = osp.join(dataset_dir, USPS[split]) n_repeat = 3 if split == "train" else None return read_image_list(data_dir, n_repeat=n_repeat) @DATASET_REGISTRY.register() class Digit5(DatasetBase): """Five digit datasets. It contains: - MNIST: hand-written digits. - MNIST-M: variant of MNIST with blended background. - SVHN: street view house number. - SYN: synthetic digits. - USPS: hand-written digits, slightly different from MNIST. For MNIST, MNIST-M, SVHN and SYN, we randomly sample 25,000 images from the training set and 9,000 images from the test set. For USPS which has only 9,298 images in total, we use the entire dataset but replicate its training set for 3 times so as to match the training set size of other domains. Reference: - Lecun et al. Gradient-based learning applied to document recognition. IEEE 1998. - Ganin et al. Domain-adversarial training of neural networks. JMLR 2016. - Netzer et al. Reading digits in natural images with unsupervised feature learning. NIPS-W 2011. """ dataset_dir = "digit5" domains = ["mnist", "mnist_m", "svhn", "syn", "usps"] def __init__(self, cfg): root = osp.abspath(osp.expanduser(cfg.DATASET.ROOT)) self.dataset_dir = osp.join(root, self.dataset_dir) self.check_input_domains( cfg.DATASET.SOURCE_DOMAINS, cfg.DATASET.TARGET_DOMAINS ) train_x = self._read_data(cfg.DATASET.SOURCE_DOMAINS, split="train") train_u = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="train") test = self._read_data(cfg.DATASET.TARGET_DOMAINS, split="test") super().__init__(train_x=train_x, train_u=train_u, test=test) def _read_data(self, input_domains, split="train"): items = [] for domain, dname in enumerate(input_domains): func = "load_" + dname domain_dir = osp.join(self.dataset_dir, dname) items_d = eval(func)(domain_dir, split=split) for impath, label in items_d: item = Datum( impath=impath, label=label, domain=domain, classname=str(label) ) items.append(item) return items