def get_balanced_weights()

in datasets/MixedPrioritizedSampler.py [0:0]


    def get_balanced_weights(self, nroot):
        """Calculate normalized generalized balanced per-sample weights.

        Every sample listed in ``self.cls_idxs[c]`` receives the weight of
        class ``c``; the weights are then rescaled so they sum to
        ``self.num_samples * self.balance_scale``.

        Args:
            nroot: ``None`` for true class-balanced weights (class weight
                ``min(cnts) / cnts[c]``, where the minimum ignores classes
                with a zero count). Otherwise a number ``>= 1``: class ``c``
                gets ``(cnts[c] / N) ** (1 / nroot) * N / cnts[c]`` with
                ``N = cnts.sum()``. ``nroot == 1`` gives every sample the
                same weight; larger values move towards class balance.

        Returns:
            numpy.ndarray of shape ``(self.num_samples,)``. Samples not
            covered by any ``self.cls_idxs[c]`` keep weight 0.

        Raises:
            NotImplementedError: if ``nroot`` is neither ``None`` nor ``>= 1``.
        """
        # Work in float64 so the ratios below are true divisions regardless
        # of the dtype self.cnts was stored with.
        cnts = np.asarray(self.cnts, dtype=np.float64)
        # Classes with a zero count must be excluded: otherwise cnts.min()
        # is 0 in the balanced branch (zeroing every weight and turning the
        # final normalization into 0/0 = NaN), and the root branch hits 0/0.
        present = cnts > 0
        cls_ws = np.zeros_like(cnts)

        if nroot is None:
            # Real balanced sampling weights: the rarest non-empty class
            # gets weight 1.
            if present.any():
                cls_ws[present] = cnts[present].min() / cnts[present]
        elif nroot >= 1:
            # Generalized balanced weights: a class's total mass is
            # proportional to its frequency raised to 1/nroot (assuming
            # len(self.cls_idxs[c]) == cnts[c] -- TODO confirm).
            total = cnts.sum()
            cls_ws[present] = (np.power(cnts[present] / total, 1. / nroot)
                               * total / cnts[present])
        else:
            raise NotImplementedError('root:{} not implemented'.format(nroot))

        # Scatter each class weight onto its samples (un-normalized).
        balanced_weights = np.zeros(self.num_samples)
        for ci in range(self.num_classes):
            balanced_weights[self.cls_idxs[ci]] = cls_ws[ci]

        # Normalize to sum to num_samples, then apply the global scale.
        # Skip when everything is zero (e.g. no samples) to avoid 0/0.
        total_weight = balanced_weights.sum()
        if total_weight > 0:
            balanced_weights *= self.num_samples / total_weight * \
                                self.balance_scale
        return balanced_weights