in datasets/oxford_pets.py [0:0]
def subsample_classes(*args, subsample="all"):
    """Divide classes into two groups and keep only the requested one.

    The distinct labels found in the first dataset are sorted and split
    in two: the first ``ceil(n / 2)`` labels form the "base" group and
    the remaining labels form the "new" group. Every dataset is then
    filtered down to the items whose label belongs to the requested
    group, and the kept labels are renumbered ``0..k-1`` following the
    sorted order.

    Args:
        args: datasets to filter, e.g. train, val and test. Each one is
            an iterable of items exposing ``impath``, ``label`` and
            ``classname``. Only the first dataset is used to discover
            the label set, so at least one dataset is required unless
            ``subsample`` is "all".
        subsample (str): which classes to keep; one of "all", "base"
            or "new".

    Returns:
        ``args`` itself (a tuple, untouched) when ``subsample`` is "all".
        Otherwise a list with one entry per input dataset, each entry
        being a list of freshly built ``Datum`` objects with remapped
        labels.
    """
    assert subsample in ["all", "base", "new"]

    if subsample == "all":
        return args

    # Only the first dataset decides which classes exist.
    labels = sorted({item.label for item in args[0]})

    # Divide classes into two halves; with an odd count the base half
    # receives the extra class.
    m = math.ceil(len(labels) / 2)

    print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
    selected = labels[:m] if subsample == "base" else labels[m:]

    # Maps each kept label to its new contiguous index. It also serves as
    # the membership test below, avoiding a list scan per item.
    relabeler = {y: y_new for y_new, y in enumerate(selected)}

    output = []
    for dataset in args:
        dataset_new = [
            Datum(
                impath=item.impath,
                label=relabeler[item.label],
                classname=item.classname,
            )
            for item in dataset
            if item.label in relabeler
        ]
        output.append(dataset_new)

    return output