in datasets/ClassAwareSampler.py [0:0]
def __init__(self, data_source, num_samples_cls=1,):
    """Group dataset indices by class for class-balanced sampling.

    Args:
        data_source: Dataset exposing a ``labels`` sequence with one label
            per sample. Any labels accepted by ``np.unique`` work; they do
            not need to be the contiguous range ``0..K-1``.
        num_samples_cls: Stored unchanged as ``self.num_samples_cls``.
            Presumably the number of consecutive samples ``__iter__`` draws
            from one class (``__iter__`` is not visible here -- confirm).

    Sets:
        class_iter: ``RandomCycleIter`` over the dense class indices
            ``0..num_classes-1``.
        data_iter_list: One ``RandomCycleIter`` per dense class index, over
            the dataset positions whose label belongs to that class.
        num_samples: Size of the largest class times the number of classes
            (0 for an empty ``labels``).
        num_samples_cls: The argument of the same name.
    """
    # return_inverse gives each sample's position in the sorted unique
    # labels, so buckets are indexed densely even when labels have gaps
    # (e.g. {0, 2, 5}) or are negative. Indexing by the raw label would
    # raise IndexError or silently merge distinct classes into one bucket.
    # For labels that are exactly 0..K-1 the dense index equals the label.
    classes, dense_labels = np.unique(data_source.labels, return_inverse=True)
    num_classes = len(classes)
    self.class_iter = RandomCycleIter(range(num_classes))
    cls_data_list = [[] for _ in range(num_classes)]
    for i, cls_idx in enumerate(dense_labels):
        cls_data_list[cls_idx].append(i)
    self.data_iter_list = [RandomCycleIter(x) for x in cls_data_list]
    # Largest class size times class count; presumably the length reported
    # by __len__ (not visible here). default=0 avoids ValueError on an
    # empty dataset.
    largest_cls = max((len(x) for x in cls_data_list), default=0)
    self.num_samples = largest_cls * num_classes
    self.num_samples_cls = num_samples_cls