def __init__()

in datasets/ClassPrioritySampler.py [0:0]


    def __init__(self, dataset, balance_scale=1.0, fixed_scale=1.0,
                 lam=None, epochs=90, cycle=0, nroot=None, manual_only=False,
                 rescale=False, root_decay=None, decay_gap=30, ptype='score',
                 pri_mode='train', momentum=0., alpha=1.0):
        """Set up per-class sampling weights and the priority tree.

        Args:
            dataset: Dataset exposing a ``labels`` sequence and ``__len__``.
                Labels must be exactly the integers ``0 .. num_classes-1``
                because they are used directly as list indices.
            balance_scale: Stored on ``self``; presumably consumed by
                ``get_balanced_weights`` (not visible here -- confirm).
            fixed_scale: Scale of the manually designed weights, forwarded to
                ``PriorityTree``. A negative value is clamped to 0 and makes
                the manual weights the backend distribution used to compute
                the per-class priority ratios.
            lam: Fixed mixing weight passed to ``get_manual_weights``. ``None``
                builds a per-epoch schedule ``self.lams`` chosen by ``cycle``.
            epochs: Number of epochs. It sets the schedule length for
                ``cycle=0`` and the base of the ``'autoexp'`` decay factor.
            cycle: Schedule shape when ``lam is None``:
                0 = one linear ramp 0 -> 1 over ``epochs`` steps;
                1 = that ramp (``epochs // 3`` steps) repeated three times;
                2 = ramp up, ramp down, ramp up (``epochs // 3`` steps each).
                For 1 and 2 the schedule has ``3 * (epochs // 3)`` entries.
            nroot: Root passed to ``get_balanced_weights``. Must be ``None`` or
                > 1 when ``root_decay`` is ``None``. With ``'autoexp'`` it is
                the starting root from which ``self.decay_factor`` is derived.
            manual_only: When true, priorities start at weight 0.
            rescale: Stored on ``self``; used outside this method.
            root_decay: ``None``, ``'exp'``, ``'linear'`` or ``'autoexp'``. Any
                decay mode forces ``lam=1``, ``manual_only=True`` and
                ``nroot=1``; ``'autoexp'`` additionally forces ``decay_gap=1``.
            decay_gap: Stored on ``self`` and printed at construction.
            ptype: ``'score'`` (initial priority weight 1.0) or ``'CE'`` /
                ``'entropy'`` (initial priority weight 6.9).
            pri_mode: Stored on ``self``; used outside this method.
            momentum: Must lie in ``[0, 1]``.
            alpha: Must be ``>= 0``; forwarded to ``PriorityTree``.

        Raises:
            AssertionError: out-of-range ``momentum``/``alpha``, an unknown
                ``root_decay``, or ``nroot <= 1`` without a decay mode.
            ValueError: ``'autoexp'`` without ``nroot`` or with
                ``epochs <= 1``, or labels that are not ``0 .. num_classes-1``.
            NotImplementedError: unknown ``cycle`` or ``ptype``.
        """
        self.dataset = dataset
        self.balance_scale = balance_scale
        self.fixed_scale = fixed_scale
        self.epochs = epochs
        self.lam = lam
        self.cycle = cycle
        self.nroot = nroot
        self.rescale = rescale
        self.manual_only = manual_only
        self.root_decay = root_decay
        self.decay_gap = decay_gap
        self.ptype = ptype
        self.pri_mode = pri_mode
        self.num_samples = len(dataset)
        self.manual_as_backend = False
        self.momentum = momentum
        self.alpha = alpha

        assert 0. <= self.momentum <= 1.0
        assert 0. <= self.alpha

        # A negative fixed_scale is a flag: drop the fixed manual term and use
        # the manual distribution as the backend of the priorities instead.
        if self.fixed_scale < 0:
            self.fixed_scale = 0
            self.manual_as_backend = True

        # If using root_decay, reset relevant parameters
        if self.root_decay in ['exp', 'linear', 'autoexp']:
            self.lam = 1
            self.manual_only = True
            self.nroot = 1
            if self.root_decay == 'autoexp':
                # The *argument* nroot (the user's starting root) is used on
                # purpose: self.nroot has just been reset to 1 above.
                # Fail clearly instead of np.power(None, ...) / 1/0 crashes.
                if nroot is None:
                    raise ValueError("root_decay='autoexp' requires nroot")
                if self.epochs <= 1:
                    raise ValueError(
                        "root_decay='autoexp' requires epochs > 1")
                self.decay_gap = 1
                # decay_factor ** (epochs - 1) == nroot
                self.decay_factor = np.power(nroot, 1/(self.epochs-1))
        else:
            assert self.root_decay is None
            assert self.nroot is None or self.nroot > 1
        print("====> Decay GAP: {}".format(self.decay_gap))

        # Take care of lambdas: either one fixed value, or a per-epoch schedule
        self.freeze = True
        if self.lam is None:
            self.freeze = False
            if cycle == 0:
                self.lams = np.linspace(0, 1, epochs)
            elif cycle == 1:
                self.lams = np.concatenate([np.linspace(0,1,epochs//3)] * 3)
            elif cycle == 2:
                self.lams = np.concatenate([np.linspace(0,1,epochs//3),
                                            np.linspace(0,1,epochs//3)[::-1],
                                            np.linspace(0,1,epochs//3)])
            else:
                raise NotImplementedError(
                    'cycle = {} not implemented'.format(cycle))
        else:
            self.lams = [self.lam]

        # Get num of samples per class; np.unique returns classes sorted.
        self.labels = labels = np.array(self.dataset.labels)
        classes, counts = np.unique(labels, return_counts=True)
        self.cls_cnts = list(counts)
        self.num_classes = len(self.cls_cnts)
        # Labels index the per-class lists below directly, so they must be
        # exactly 0..num_classes-1; otherwise the counts (sorted-class order)
        # and the index lists (raw label order) would silently disagree.
        if not np.array_equal(classes, np.arange(self.num_classes)):
            raise ValueError(
                'labels must be the integers 0..{}; got classes {}'.format(
                    self.num_classes - 1, classes))
        self.cnts = np.array(self.cls_cnts).astype(float)

        # Get per-class image indexes. RandomCycleIter receives the plain
        # lists; cls_idxs entries are replaced by arrays only afterwards.
        self.cls_idxs = [[] for _ in range(self.num_classes)]
        for i, label in enumerate(self.dataset.labels):
            self.cls_idxs[label].append(i)
        self.data_iter_list = [RandomCycleIter(x) for x in self.cls_idxs]
        for ci in range(self.num_classes):
            self.cls_idxs[ci] = np.array(self.cls_idxs[ci])

        # Build balanced weights based on class counts. Order matters: these
        # read self.nroot / self.lams as (possibly) reset above.
        self.balanced_weights = self.get_balanced_weights(self.nroot)
        self.uniform_weights = self.get_uniform_weights()
        self.manual_weights = self.get_manual_weights(self.lams[0])

        back_weights = self.uniform_weights

        # Calculate priority ratios that reshape priority into target distribution
        self.per_cls_ratios = self.get_cls_ratios(
            self.manual_weights if self.manual_as_backend else back_weights)
        self.per_example_ratios = self.broadcast(self.per_cls_ratios)

        # Setup priority tree
        if self.ptype == 'score':
            self.init_weight = 1.
        elif self.ptype in ['CE', 'entropy']:
            # NOTE(review): 6.9 ~= ln(1000), presumably the loss/entropy of a
            # uniform prediction over 1000 classes -- confirm.
            self.init_weight = 6.9
        else:
            raise NotImplementedError('ptype {} not implemented'.format(self.ptype))
        if self.manual_only:
            self.init_weight = 0.
        self.per_example_uni_weights = np.ones(self.num_samples) * self.init_weight
        self.per_example_velocities = np.zeros(self.num_samples)
        init_priorities = self.init_weight * self.uniform_weights * self.per_cls_ratios
        self.ptree = PriorityTree(self.num_classes, init_priorities,
                                  self.manual_weights.copy(), fixed_scale=self.fixed_scale,
                                  alpha=self.alpha)