datasets/SCOODBenchmarkDataset.py (30 lines of code) (raw):
import os, ast
import numpy as np
from PIL import Image
import torch
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
class SCOODDataset(torch.utils.data.Dataset):
def __init__(self, root, id_name, ood_name, transform):
super(SCOODDataset, self).__init__()
assert id_name in ['cifar10', 'cifar100']
imglist_path = os.path.join(root, 'data/imglist/benchmark_%s' % id_name, 'test_%s.txt' % ood_name)
with open(imglist_path) as fp:
self.imglist = fp.readlines()
self.transform = transform
self.root = root
print("SCOODDataset (id %s, ood %s) Contain %d images" % (id_name, ood_name, len(self.imglist)))
def __len__(self):
return len(self.imglist)
def __getitem__(self, index):
# parse the string in imglist file:
line = self.imglist[index].strip("\n")
tokens = line.split(" ", 1)
image_name, extra_str = tokens[0], tokens[1]
extras = ast.literal_eval(extra_str)
sc_label = extras['sc_label'] # the ood label is here. -1 means ood.
# read image according to image name:
img_path = os.path.join(self.root, 'data', 'images', image_name)
with open(img_path, 'rb') as f:
img = Image.open(f).convert('RGB')
if self.transform is not None:
img = self.transform(img)
return img, sc_label