datasets/ImbalanceCIFAR.py
"""
Adopted from https://github.com/Megvii-Nanjing/BBN
Customized by Kaihua Tang
"""
import torchvision
import numpy as np
from PIL import Image
# CIFAR10 long-tailed split (under the default exp imbalance):
# many-shot classes: 0, 1, 2
# medium-shot classes: 3, 4, 5, 6
# few-shot classes: 7, 8, 9
class IMBALANCECIFAR10(torchvision.datasets.CIFAR10):
    cls_num = 10

    def __init__(self, train, transform, imbalance_ratio=0.01, root='/ssd1/haotao/datasets', imb_type='exp'):
        super(IMBALANCECIFAR10, self).__init__(root, train, transform=transform, target_transform=None, download=True)
        self.train = train
        if self.train:
            img_num_list = self.get_img_num_per_cls(self.cls_num, imb_type, imbalance_ratio)
            self.gen_imbalanced_data(img_num_list)
        self.labels = self.targets
        # print("{} Mode: Contain {} images".format("train" if train else "test", len(self.data)))

    def _get_class_dict(self):
        # Map each category_id to the list of sample indices belonging to that class.
        class_dict = dict()
        for i, anno in enumerate(self.get_annotations()):
            cat_id = anno["category_id"]
            if cat_id not in class_dict:
                class_dict[cat_id] = []
            class_dict[cat_id].append(i)
        return class_dict

    def get_img_num_per_cls(self, cls_num, imb_type, imb_factor):
        img_max = len(self.data) / cls_num
        img_num_per_cls = []
        if imb_type == 'exp':
            # Exponential decay: class k keeps img_max * imb_factor**(k / (cls_num - 1)) samples.
            for cls_idx in range(cls_num):
                num = img_max * (imb_factor**(cls_idx / (cls_num - 1.0)))
                img_num_per_cls.append(int(num))
        elif imb_type == 'step':
            # Step imbalance: the first half of the classes keep img_max samples each,
            # the second half keep img_max * imb_factor.
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max))
            for cls_idx in range(cls_num // 2):
                img_num_per_cls.append(int(img_max * imb_factor))
        else:
            # No imbalance: every class keeps img_max samples.
            img_num_per_cls.extend([int(img_max)] * cls_num)
        self.img_num_per_cls = img_num_per_cls
        return img_num_per_cls
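
    # For reference: with the defaults imb_type='exp' and imbalance_ratio=0.01 on
    # CIFAR-10 (img_max = 50000 / 10 = 5000), the loop above yields per-class
    # training counts
    #   [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]  (12406 images total),
    # which is where the many/medium/few grouping noted at the top of this file
    # comes from. With imb_type='step' and the same factor, the first 5 classes
    # keep 5000 images each and the last 5 keep 50.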

    def gen_imbalanced_data(self, img_num_per_cls):
        # Truncate each class to its target count, keeping the first N samples per class.
        new_data = []
        new_targets = []
        targets_np = np.array(self.targets, dtype=np.int64)
        classes = np.unique(targets_np)
        self.num_per_cls_dict = dict()
        for the_class, the_img_num in zip(classes, img_num_per_cls):
            self.num_per_cls_dict[the_class] = the_img_num
            idx = np.where(targets_np == the_class)[0]
            # np.random.shuffle(idx)  # Removed: shuffling here is problematic because
            # different runs would then train on different subsets of samples.
            selec_idx = idx[:the_img_num]
            new_data.append(self.data[selec_idx, ...])
            new_targets.extend([the_class, ] * the_img_num)
        new_data = np.vstack(new_data)
        self.data = new_data
        self.targets = new_targets

    def __getitem__(self, index):
        img, label = self.data[index], self.labels[index]
        # doing this so that it is consistent with all other datasets
        # to return a PIL Image
        img = Image.fromarray(img)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __len__(self):
        return len(self.labels)

    def get_num_classes(self):
        return self.cls_num

    def get_annotations(self):
        annos = []
        for label in self.labels:
            annos.append({'category_id': int(label)})
        return annos

    def get_cls_num_list(self):
        cls_num_list = []
        for i in range(self.cls_num):
            cls_num_list.append(self.num_per_cls_dict[i])
        return cls_num_list

class IMBALANCECIFAR100(IMBALANCECIFAR10):
    """`CIFAR100 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
    This is a subclass of the `CIFAR10` Dataset.
    """
    cls_num = 100
    base_folder = 'cifar-100-python'
    url = "https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = 'eb9058c3a382ffc7106e4002c42a8d85'
    train_list = [
        ['train', '16019d7e3df5f24257cddd939b257f8d'],
    ]
    test_list = [
        ['test', 'f0ef6b0ae62326f3e7ffdfab6717acfc'],
    ]
    meta = {
        'filename': 'meta',
        'key': 'fine_label_names',
        'md5': '7973b15100ade9c7d40fb424638fde48',
    }
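
# A minimal usage sketch: the './data' root, the transforms, and the loss reweighting
# below are illustrative choices, not part of the classes above (CIFAR is downloaded to
# the given root if missing). It builds the long-tailed training split alongside the
# untouched, balanced test split and wraps both in DataLoaders.
if __name__ == '__main__':
    import torch
    import torchvision.transforms as transforms

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    test_transform = transforms.ToTensor()

    train_set = IMBALANCECIFAR10(train=True, transform=train_transform,
                                 imbalance_ratio=0.01, root='./data')
    test_set = IMBALANCECIFAR10(train=False, transform=test_transform, root='./data')

    print('train size:', len(train_set))                   # 12406 for ratio 0.01
    print('class counts:', train_set.get_cls_num_list())   # [5000, 2997, ..., 50]

    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=128, shuffle=True, num_workers=4)
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=128, shuffle=False, num_workers=4)

    # One common (illustrative) way to counter the imbalance is to reweight the loss
    # by inverse class frequency; this is just one option, not prescribed by this file.
    counts = torch.tensor(train_set.get_cls_num_list(), dtype=torch.float)
    criterion = torch.nn.CrossEntropyLoss(weight=counts.sum() / (len(counts) * counts))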