in low_shot.py [0:0]
def __init__(self, file_handle, base_classes, novel_classes, novel_idx, max_per_label=0, generator_fn=None, generator=None):
    """Set up base/novel example bookkeeping from an open feature file.

    Args:
        file_handle: Mapping-like handle (presumably an open ``h5py.File``;
            confirm against callers) exposing ``'all_feats'``,
            ``'all_labels'`` and ``'count'``.
        base_classes: Labels treated as base classes.
        novel_classes: Labels treated as novel classes.
        novel_idx: Row indices into ``'all_feats'`` / ``'all_labels'`` of the
            novel-class examples. May be unsorted or contain repeats.
        max_per_label: When > 0, ``generator_fn`` is called to produce the
            final novel features/labels; the value is forwarded to it
            unchanged (its exact meaning is defined by ``generator_fn``).
        generator_fn: Callable ``(feats, labels, generator, max_per_label)``
            returning ``(feats, labels)``. Required when ``max_per_label > 0``.
        generator: Opaque object forwarded to ``generator_fn``.

    Raises:
        ValueError: If ``max_per_label > 0`` but no ``generator_fn`` is given.
    """
    self.f = file_handle
    # Features stay as a dataset handle; only the novel rows are read below.
    self.all_feats_dset = self.f['all_feats']
    all_labels_dset = self.f['all_labels']
    # Labels are read in full (`[...]` selects the whole dataset).
    self.all_labels = all_labels_dset[...]

    # Base class examples: rows whose label is a base class, restricted to
    # row indices below the stored 'count'.
    base_ids = np.where(np.isin(self.all_labels, base_classes))[0]
    total = self.f['count'][0]
    self.base_class_ids = base_ids[base_ids < total]

    # Novel class examples.
    idx = np.asarray(novel_idx)
    if idx.ndim == 1 and idx.size > 0 and idx.dtype.kind in 'iu':
        # h5py fancy selection only accepts strictly increasing,
        # duplicate-free indices. Read the sorted unique rows once, then
        # restore the caller's order (and any repeats) in memory.
        uniq_idx, inverse = np.unique(idx, return_inverse=True)
        novel_feats = self.all_feats_dset[uniq_idx, :][inverse]
    else:
        novel_feats = self.all_feats_dset[novel_idx, :]
    novel_labels = self.all_labels[novel_idx]

    # Hallucinate (augment) novel examples if requested.
    if max_per_label > 0:
        if generator_fn is None:
            raise ValueError('generator_fn is required when max_per_label > 0')
        novel_feats, novel_labels = generator_fn(novel_feats, novel_labels, generator, max_per_label)
    self.novel_feats = novel_feats
    self.novel_labels = novel_labels

    self.base_classes = base_classes
    self.novel_classes = novel_classes
    # Fraction of all classes that are base classes.
    self.frac = float(len(base_classes)) / float(len(novel_classes) + len(base_classes))
    self.all_classes = np.concatenate((base_classes, novel_classes))