in datasets/MixedPrioritizedSampler.py [0:0]
def get_balanced_weights(self, nroot):
    """Calculate normalized generalized balanced per-sample weights.

    Args:
        nroot: ``None`` for fully class-balanced weights (per-class weight
            ``min(cnts) / cnts``, with the minimum taken over non-empty
            classes), or a number ``>= 1`` for generalized balancing, where a
            class with ``n_c`` of ``N`` counted samples gets per-sample
            weight ``(n_c / N) ** (1 / nroot) * N / n_c``. ``nroot == 1``
            gives uniform per-sample weights; as ``nroot`` grows the weights
            approach ``N / n_c`` (class-balanced).

    Returns:
        A float array of length ``self.num_samples``. Entries listed in
        ``self.cls_idxs[ci]`` hold class ``ci``'s weight. The array is
        rescaled to sum to ``self.num_samples * self.balance_scale``,
        unless every weight is zero, in which case it is returned as-is.

    Raises:
        NotImplementedError: if ``nroot`` is not ``None`` and is below 1.
    """
    # presumably self.cnts[ci] == len(self.cls_idxs[ci]); verify vs. caller
    cnts = np.asarray(self.cnts)
    # Keep empty classes out of the per-class arithmetic: a zero count would
    # make cnts.min() == 0 (zeroing every class) and produce 0/0 NaNs.
    nonempty = cnts > 0
    cls_ws = np.zeros(cnts.shape, dtype=np.float64)
    if nroot is None:
        # Real balanced sampling weights: inverse-frequency, scaled so the
        # rarest non-empty class has weight 1.
        if nonempty.any():
            cls_ws[nonempty] = cnts[nonempty].min() / cnts[nonempty]
    elif nroot >= 1:
        # Generalized balanced weights: per-class total mass is
        # proportional to n_c ** (1 / nroot).
        total = cnts.sum()
        cls_ws[nonempty] = (np.power(cnts[nonempty] / total, 1. / nroot)
                            * total / cnts[nonempty])
    else:
        raise NotImplementedError('root:{} not implemented'.format(nroot))
    # Get un-normalized per-sample weights by broadcasting class weights.
    balanced_weights = np.zeros(self.num_samples)
    for ci in range(self.num_classes):
        balanced_weights[self.cls_idxs[ci]] = cls_ws[ci]
    # Normalization and rescale. An all-zero vector has nothing to
    # normalize; dividing by its zero sum would only produce NaN/inf.
    weight_sum = balanced_weights.sum()
    if weight_sum > 0:
        balanced_weights *= self.num_samples / weight_sum * \
            self.balance_scale
    return balanced_weights