def get_balanced_weights()

in datasets/ClassPrioritySampler.py [0:0]


    def get_balanced_weights(self, nroot):
        """Calculate normalized generalized balanced class weights.

        One weight is produced per entry of ``self.cnts``. The weights are
        then rescaled so that they sum to
        ``self.num_samples * self.balance_scale``.

        Args:
            nroot: Controls how closely the weights follow the class counts.
                ``None`` gives every class the same weight. A value ``>= 1``
                makes each weight proportional to
                ``(cnt / total) ** (1 / nroot)``. ``nroot == 1`` reproduces
                the count frequencies, and larger values flatten the
                distribution toward uniform. Classes with a zero count keep
                weight 0.

        Returns:
            Float array with one weight per entry of ``self.cnts``.

        Raises:
            NotImplementedError: If ``nroot`` is not ``None`` and is below 1.
            ValueError: If ``nroot >= 1`` and the counts sum to zero, which
                would otherwise silently yield NaN weights.
        """
        # NOTE(review): assumes self.cnts is a 1-D numpy array of per-class
        # counts (it needs len() and .sum()) — confirm against the caller.
        cnts = self.cnts
        if nroot is None:
            # Real balanced sampling: every class gets the same weight.
            cls_ws = np.ones(len(cnts))
        elif nroot >= 1:
            # Generalized balancing: nroot-th root of the class frequencies.
            # Any constant scale factor is irrelevant because the weights are
            # renormalized below.
            total = cnts.sum()
            if total <= 0:
                raise ValueError(
                    'cannot compute generalized balanced weights: '
                    'class counts sum to {}'.format(total))
            cls_ws = np.power(cnts / total, 1. / nroot)
        else:
            raise NotImplementedError('root:{} not implemented'.format(nroot))

        # Normalize and rescale so the weights sum to
        # num_samples * balance_scale.
        return cls_ws * (self.num_samples / cls_ws.sum() * self.balance_scale)