ood/datasets.py (145 lines of code) (raw):

# Copyright (c) Alibaba, Inc. and its affiliates.
"""Dataset and dataloader helpers for out-of-distribution (OOD) evaluation."""
import ast
import os

import numpy as np
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader, Subset, dataloader
from torchvision import transforms


def _ood_target(_class_index):
    """Return -1 for any ImageFolder class index (every sample is OOD).

    Defined at module level (rather than as a lambda) so that datasets holding
    it as ``target_transform`` stay picklable for multi-process data loading.
    """
    return -1


class CLIPFeatDataset(torch.utils.data.Dataset):
    """Dataset backed by pre-extracted image features stored as ``.pt`` files.

    Args:
        data_dir: Directory containing ``*_image_features.pt``,
            ``*_labels.pt`` and, optionally, ``*_paths.txt``.
        epoch: Epoch index used in the file-name prefix (``ep{epoch}_``);
            ``None`` omits the epoch part of the prefix.
        nshot: When positive, ``{nshot}shot_`` is prepended to the prefix.
        split: Split name used in the file-name prefix (e.g. ``'train'``).
    """

    def __init__(self, data_dir, epoch=0, nshot=-1, split='train'):
        super().__init__()
        self.data_dir = data_dir
        self.nshot = nshot
        self.split = split

        # ``self.epoch`` is not set yet, so this first call always loads.
        self.load_data(epoch)
        self.epoch = epoch

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        """Return ``(feature, target)`` without paths, else a dict sample."""
        if self.paths is None:  # for OOD
            return self.features[idx], self.targets[idx]
        else:
            output = {
                'img': self.features[idx],
                "label": self.targets[idx],
                "domain": '',
                "impath": self.paths[idx],
                "index": idx
            }
            return output

    def load_data(self, epoch=None):
        """Load features, labels and (if present) image paths for ``epoch``.

        Does nothing when ``epoch`` equals the epoch already recorded on the
        instance.
        """
        if hasattr(self, 'epoch') and epoch == self.epoch:
            return

        prefix = ('' if epoch is None else f'ep{epoch}_') + self.split
        if self.nshot > 0:
            prefix = f'{self.nshot}shot_' + prefix

        self.features = torch.load(f'{self.data_dir}/{prefix}_image_features.pt', map_location='cpu')
        self.targets = torch.load(f'{self.data_dir}/{prefix}_labels.pt', map_location='cpu').long()

        path_file = f'{self.data_dir}/{prefix}_paths.txt'
        if os.path.exists(path_file):
            with open(path_file, 'r') as f:
                self.paths = f.read().splitlines()
            # NOTE(review): the epoch is only recorded when a paths file
            # exists -- confirm this is intended.
            self.epoch = epoch
            if 'train' in prefix:
                print(f'\nLoaded dataset from epoch {epoch}\n')
        else:
            # No paths file: __getitem__ falls back to (feature, target).
            self.paths = None


class LargeOODDataset(torchvision.datasets.ImageFolder):
    """ImageFolder over one of the named OOD sets; every target is -1.

    Args:
        root: Directory containing the OOD dataset folders.
        id_name: Name of the in-distribution dataset (used for logging only).
        ood_name: One of ``'texture'``, ``'inaturalist'``, ``'places'``,
            ``'sun'``; any other value raises ``KeyError``.
        transform: Image transform forwarded to ``ImageFolder``.
    """

    def __init__(self, root, id_name, ood_name, transform):
        ood_name_dict = {'texture': 'dtd', 'inaturalist': 'iNaturalist', 'places': 'Places', 'sun': 'SUN'}
        data_root = f'{root}/{ood_name_dict[ood_name]}'
        super().__init__(data_root, transform, target_transform=_ood_target)  # all -1
        print("LargeOODDataset (id %s, ood %s) Contain %d images" % (id_name, ood_name, len(self)))


class SemanticOODDataset(torchvision.datasets.ImageFolder):
    """ImageFolder over ``{root}/{id_name}/ood_{ood_name}``; every target is -1.

    Args:
        root: Dataset root directory.
        id_name: In-distribution dataset name; must contain ``'imagenet'``.
        ood_name: One of ``'easy'``, ``'rand'``, ``'hard'``.
        transform: Image transform forwarded to ``ImageFolder``.
    """

    def __init__(self, root, id_name, ood_name, transform):
        assert 'imagenet' in id_name and ood_name in ['easy', 'rand', 'hard']
        data_root = f'{root}/{id_name}/ood_{ood_name}'
        super().__init__(data_root, transform, target_transform=_ood_target)  # all -1
        print("SemanticOODDataset (id %s, ood %s) Contain %d images" % (id_name, ood_name, len(self)))


class ClassOODDataset(torchvision.datasets.ImageFolder):
    """ImageFolder over a ``severity`` or ``imagenet*-o*`` set; targets are -1.

    Args:
        root: Dataset root directory.
        id_name: In-distribution dataset name; must contain ``'imagenet'``.
        ood_name: Either a name containing ``'severity'`` (resolved under
            ``{root}/{id_name}``) or a name containing both ``'imagenet'``
            and ``'-o'`` (resolved as ``{root}/../{ood_name}/images``).
        transform: Image transform forwarded to ``ImageFolder``.
    """

    def __init__(self, root, id_name, ood_name, transform):
        assert 'imagenet' in id_name
        if 'severity' in ood_name:
            data_root = f'{root}/{id_name}/{ood_name}'
        else:
            assert 'imagenet' in ood_name and '-o' in ood_name
            data_root = f'{root}/../{ood_name}/images'
        super().__init__(data_root, transform, target_transform=_ood_target)  # all -1
        print("ClassOODDataset (id %s, ood %s) Contain %d images" % (id_name, ood_name, len(self)))


class SCOODDataset(torch.utils.data.Dataset):
    """Dataset driven by an image-list file under ``data/imglist``.

    Each line of the list holds an image name followed by a Python-literal
    dict that carries an ``'sc_label'`` entry.

    Args:
        root: Directory containing ``data/imglist`` and ``data/images``.
        id_name: One of ``'cifar10'``, ``'cifar100'``, ``'imagenet'``,
            ``'imagenet100'``; the ImageNet names use the ``cifar10`` list.
        ood_name: OOD set name; ``'cifar'`` selects the CIFAR variant that is
            not ``id_name``.
        transform: Optional callable applied to the loaded RGB image.
    """

    def __init__(self, root, id_name, ood_name, transform):
        super(SCOODDataset, self).__init__()
        assert id_name in ['cifar10', 'cifar100', 'imagenet', 'imagenet100']
        if 'imagenet' in id_name:
            id_name = 'cifar10'
        if ood_name == 'cifar':
            assert 'cifar' in id_name
            if id_name == 'cifar10':
                ood_name = 'cifar100'
            else:
                ood_name = 'cifar10'

        imglist_path = os.path.join(root, 'data/imglist/benchmark_%s' % id_name, 'test_%s.txt' % ood_name)
        with open(imglist_path) as fp:
            self.imglist = fp.readlines()

        self.transform = transform
        self.root = root
        print("SCOODDataset (id %s, ood %s) Contain %d images" % (id_name, ood_name, len(self.imglist)))

    def __len__(self):
        return len(self.imglist)

    def __getitem__(self, index):
        """Return ``(image, sc_label)`` for the ``index``-th list entry."""
        # parse the string in imglist file:
        line = self.imglist[index].strip("\n")
        tokens = line.split(" ", 1)
        image_name, extra_str = tokens[0], tokens[1]
        extras = ast.literal_eval(extra_str)
        sc_label = extras['sc_label']  # the ood label is here. -1 means ood.

        # read image according to image name:
        img_path = os.path.join(self.root, 'data', 'images', image_name)
        with open(img_path, 'rb') as f:
            img = Image.open(f).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, sc_label


class TinyImages(torch.utils.data.Dataset):
    """Images loaded from ``tinyimages80m/300K_random_images.npy``; label -1.

    Args:
        root: Directory containing the ``tinyimages80m`` folder.
        transform: Optional callable applied to each PIL image.
    """

    def __init__(self, root, transform):
        super(TinyImages, self).__init__()
        self.data = np.load(os.path.join(root, 'tinyimages80m', '300K_random_images.npy'))
        self.transform = transform
        print("TinyImages Contain {} images".format(len(self.data)))

    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        return img, -1  # -1 is the class

    def __len__(self):
        return len(self.data)


def tinyimages300k_dataloaders(num_samples=300000, train_batch_size=64, num_workers=8,
                               data_root_path='/ssd1/haotao/datasets'):
    """Build a shuffled training loader over the first ``num_samples`` images.

    Args:
        num_samples: Number of leading images to use; clamped to the dataset
            length so the subset never holds out-of-range indices.
        train_batch_size: Batch size of the returned loader.
        num_workers: Number of worker processes for the loader.
        data_root_path: Directory containing the ``tinyimages80m`` folder.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, ``drop_last=True`` and
        ``pin_memory=True``.
    """
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    # TinyImages only accepts (root, transform); the former ``train=`` and
    # ``download=`` keyword arguments raised TypeError on every call.
    dataset = TinyImages(data_root_path, transform=train_transform)
    num_samples = min(int(num_samples), len(dataset))
    train_set = Subset(dataset, list(range(num_samples)))

    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True,
                              num_workers=num_workers, drop_last=True, pin_memory=True)
    return train_loader


class _RepeatSampler:
    """ Sampler that repeats forever

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class InfiniteDataLoader(dataloader.DataLoader):
    """ Dataloader that reuses workers

    Uses same syntax as vanilla DataLoader
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Set through object.__setattr__ to bypass DataLoader's own
        # __setattr__ when swapping in the endlessly repeating batch sampler.
        object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
        # One underlying iterator is created here and reused by __iter__.
        self.iterator = super().__iter__()

    def __len__(self):
        # Length of the wrapped (original) batch sampler.
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        # Yield exactly one pass worth of batches from the shared iterator.
        for _ in range(len(self)):
            yield next(self.iterator)