Dassl.pytorch/dassl/data/datasets/base_dataset.py

import os
import random
import os.path as osp
import tarfile
import zipfile
from collections import defaultdict

import gdown

from dassl.utils import check_isfile


class Datum:
    """Data instance which defines the basic attributes.

    Args:
        impath (str): image path.
        label (int): class label.
        domain (int): domain label.
        classname (str): class name.
    """

    def __init__(self, impath="", label=0, domain=0, classname=""):
        assert isinstance(impath, str)
        assert check_isfile(impath)

        self._impath = impath
        self._label = label
        self._domain = domain
        self._classname = classname

    @property
    def impath(self):
        return self._impath

    @property
    def label(self):
        return self._label

    @property
    def domain(self):
        return self._domain

    @property
    def classname(self):
        return self._classname


class DatasetBase:
    """A unified dataset class for
    1) domain adaptation
    2) domain generalization
    3) semi-supervised learning
    """

    dataset_dir = ""  # the directory where the dataset is stored
    domains = []  # string names of all domains

    def __init__(self, train_x=None, train_u=None, val=None, test=None):
        self._train_x = train_x  # labeled training data
        self._train_u = train_u  # unlabeled training data (optional)
        self._val = val  # validation data (optional)
        self._test = test  # test data
        self._num_classes = self.get_num_classes(train_x)
        self._lab2cname, self._classnames = self.get_lab2cname(train_x)

    @property
    def train_x(self):
        return self._train_x

    @property
    def train_u(self):
        return self._train_u

    @property
    def val(self):
        return self._val

    @property
    def test(self):
        return self._test

    @property
    def lab2cname(self):
        return self._lab2cname

    @property
    def classnames(self):
        return self._classnames

    @property
    def num_classes(self):
        return self._num_classes

    @staticmethod
    def get_num_classes(data_source):
        """Count number of classes.

        Args:
            data_source (list): a list of Datum objects.
        """
        label_set = set()
        for item in data_source:
            label_set.add(item.label)
        return max(label_set) + 1

    @staticmethod
    def get_lab2cname(data_source):
        """Get a label-to-classname mapping (dict).

        Args:
            data_source (list): a list of Datum objects.
        """
        container = set()
        for item in data_source:
            container.add((item.label, item.classname))
        mapping = {label: classname for label, classname in container}
        labels = list(mapping.keys())
        labels.sort()
        classnames = [mapping[label] for label in labels]
        return mapping, classnames

    def check_input_domains(self, source_domains, target_domains):
        assert len(source_domains) > 0, "source_domains (list) is empty"
        assert len(target_domains) > 0, "target_domains (list) is empty"
        self.is_input_domain_valid(source_domains)
        self.is_input_domain_valid(target_domains)

    def is_input_domain_valid(self, input_domains):
        for domain in input_domains:
            if domain not in self.domains:
                raise ValueError(
                    "Input domain must belong to {}, "
                    "but got [{}]".format(self.domains, domain)
                )

    def download_data(self, url, dst, from_gdrive=True):
        if not osp.exists(osp.dirname(dst)):
            os.makedirs(osp.dirname(dst))

        if from_gdrive:
            gdown.download(url, dst, quiet=False)
        else:
            raise NotImplementedError

        print("Extracting file ...")

        if dst.endswith(".zip"):
            zip_ref = zipfile.ZipFile(dst, "r")
            zip_ref.extractall(osp.dirname(dst))
            zip_ref.close()
        elif dst.endswith(".tar"):
            tar = tarfile.open(dst, "r:")
            tar.extractall(osp.dirname(dst))
            tar.close()
        elif dst.endswith(".tar.gz"):
            tar = tarfile.open(dst, "r:gz")
            tar.extractall(osp.dirname(dst))
            tar.close()
        else:
            raise NotImplementedError

        print("File extracted to {}".format(osp.dirname(dst)))

    def generate_fewshot_dataset(self, *data_sources, num_shots=-1, repeat=False):
        """Generate a few-shot dataset (typically for the training set).

        This function is useful when one wants to evaluate a model
        in a few-shot learning setting where each class only contains
        a small number of images.

        Args:
            data_sources: one or more lists of Datum objects.
            num_shots (int): number of instances per class to sample.
            repeat (bool): repeat images if needed (default: False).
        """
        if num_shots < 1:
            if len(data_sources) == 1:
                return data_sources[0]
            return data_sources

        print(f"Creating a {num_shots}-shot dataset")

        output = []

        for data_source in data_sources:
            tracker = self.split_dataset_by_label(data_source)
            dataset = []

            for label, items in tracker.items():
                if len(items) >= num_shots:
                    sampled_items = random.sample(items, num_shots)
                else:
                    if repeat:
                        sampled_items = random.choices(items, k=num_shots)
                    else:
                        sampled_items = items
                dataset.extend(sampled_items)

            output.append(dataset)

        if len(output) == 1:
            return output[0]

        return output

    def split_dataset_by_label(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into class-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.label].append(item)

        return output

    def split_dataset_by_domain(self, data_source):
        """Split a dataset, i.e. a list of Datum objects,
        into domain-specific groups stored in a dictionary.

        Args:
            data_source (list): a list of Datum objects.
        """
        output = defaultdict(list)

        for item in data_source:
            output[item.domain].append(item)

        return output