in datasets/ClassPrioritySampler.py [0:0]
def get_balanced_weights(self, nroot):
    """Calculate generalized class-balanced sampling weights.

    Args:
        nroot: ``None`` for fully balanced weights (every class gets the
            same weight), or a number ``>= 1`` for generalized balancing,
            where each class weight is proportional to
            ``cnt ** (1 / nroot)``. ``nroot == 1`` keeps the raw class
            frequencies; larger roots flatten them toward uniform.

    Returns:
        np.ndarray with one weight per entry of ``self.cnts``, rescaled so
        the weights sum to ``self.num_samples * self.balance_scale``.

    Raises:
        NotImplementedError: if ``nroot`` is neither ``None`` nor ``>= 1``.
        ValueError: if the un-normalized weights cannot be normalized
            (no classes, or all class counts are zero).
    """
    # Accept any array-like of per-class counts, not only ndarrays.
    cnts = np.asarray(self.cnts)
    if nroot is None:
        # Fully balanced: identical (un-normalized) weight for every class.
        cls_ws = np.ones(len(cnts))
    elif nroot >= 1:
        # Generalized balancing: take the nroot-th root of the class counts.
        # Any constant factor (e.g. dividing by the total count first) is
        # cancelled by the normalization below, so it is omitted.
        cls_ws = np.power(cnts, 1. / nroot)
    else:
        raise NotImplementedError('root:{} not implemented'.format(nroot))

    ws_sum = cls_ws.sum()
    # `not > 0` also rejects NaN; dividing by it would yield NaN/inf weights.
    if not ws_sum > 0:
        raise ValueError(
            'cannot normalize balanced weights (weight sum={})'.format(ws_sum))
    # Rescale so the weights sum to num_samples * balance_scale.
    return cls_ws * (self.num_samples / ws_sum * self.balance_scale)