in datasets.py [0:0]
def subsample_(self, subsample_what):
    """Subsample the dataset in place so that groups or classes are balanced.

    Visits the current indices ``self.i`` in a random order drawn from
    ``torch.randperm`` and keeps an example only while its unit still has
    fewer than ``min_size`` kept examples. The unit is the (label, group)
    cell for ``"groups"`` and the label alone for ``"classes"``.
    ``min_size`` is the smallest non-empty entry of ``self.group_sizes``
    (for ``"groups"``) or ``self.class_sizes`` (for ``"classes"``).
    Afterwards ``self.i`` is replaced by the kept indices and
    ``self.count_groups()`` is called. That call presumably recomputes the
    size statistics; verify against its definition.

    Args:
        subsample_what: ``"groups"`` to balance every (label, group) cell, or
            ``"classes"`` to balance labels only.

    Raises:
        ValueError: if ``subsample_what`` is neither ``"groups"`` nor
            ``"classes"``. Any other value would match no unit below and
            silently drop every example.
    """
    if subsample_what not in ("groups", "classes"):
        raise ValueError(
            f"subsample_what must be 'groups' or 'classes', got {subsample_what!r}"
        )
    by_groups = subsample_what == "groups"

    sizes = self.group_sizes if by_groups else self.class_sizes
    # Ignore empty units: a single zero would make min_size 0, and then no
    # example could ever be kept, emptying the whole dataset.
    min_size = min((size for size in sizes if size > 0), default=0)

    # One counter per balanced unit. The cell index nb_groups * y + g assumes
    # 0 <= g < nb_groups and 0 <= y < nb_labels (TODO confirm upstream).
    num_units = self.nb_groups * self.nb_labels if by_groups else self.nb_labels
    counts = [0] * num_units

    new_i = []
    for p in torch.randperm(len(self)).tolist():
        idx = self.i[p]
        y, g = int(self.y[idx]), int(self.g[idx])
        unit = self.nb_groups * y + g if by_groups else y
        if counts[unit] < min_size:
            counts[unit] += 1
            new_i.append(idx)

    self.i = new_i
    self.count_groups()