def __init__()

in datasets/MixedPrioritizedSampler.py [0:0]


    def __init__(self, dataset, balance_scale=1.0, fixed_scale=1.0,
                 lam=None, epochs=90, cycle=0, nroot=None, manual_only=False,
                 rescale=False, root_decay=None, decay_gap=30, ptype='score',
                 alpha=1.0):
        """
        """
        self.dataset = dataset
        self.balance_scale = balance_scale
        self.fixed_scale = fixed_scale
        self.epochs = epochs
        self.lam = lam
        self.cycle = cycle
        self.nroot = nroot
        self.rescale = rescale
        self.manual_only = manual_only
        self.root_decay = root_decay
        self.decay_gap = decay_gap
        self.ptype = ptype
        self.num_samples = len(dataset)
        self.alpha = alpha

        # If using root_decay, reset relevant parameters
        if self.root_decay in ['exp', 'linear', 'autoexp']:
            self.lam = 1
            self.manual_only = True
            self.nroot = 1
            if self.root_decay == 'autoexp':
                self.decay_gap = 1
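                # Note: the decay factor uses the original nroot argument (self.nroot was reset to 1 above)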
                self.decay_factor = np.power(nroot, 1/(self.epochs-1))
        else:
            assert self.root_decay is None
            assert self.nroot is None or self.nroot >= 2
        print("====> Decay GAP: {}".format(self.decay_gap))

        # Set up the lambda schedule
        if self.lam is None:
            self.freeze = False
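            # Build a per-epoch lambda schedule:
            #   cycle 0: one linear ramp from 0 to 1 over all epochs
            #   cycle 1: three repeated 0 -> 1 ramps
            #   cycle 2: ramp up, ramp down, ramp up again
            # (cycles 1 and 2 assume epochs is divisible by 3)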
            if cycle == 0:
                self.lams = np.linspace(0, 1, epochs)
            elif cycle == 1:
                self.lams = np.concatenate([np.linspace(0,1,epochs//3)] * 3)
            elif cycle == 2:
                self.lams = np.concatenate([np.linspace(0,1,epochs//3),
                                            np.linspace(0,1,epochs//3)[::-1],
                                            np.linspace(0,1,epochs//3)])
            else:
                raise NotImplementedError(
                    'cycle = {} not implemented'.format(cycle))
        else:
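            # A fixed lam freezes the schedule at that single value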
            self.lams = [self.lam]
            self.freeze = True

        # Get num of samples per class
        self.cls_cnts = []
        self.labels = labels = np.array(self.dataset.labels)
        for l in np.unique(labels):
            self.cls_cnts.append(np.sum(labels==l))
        self.num_classes = len(self.cls_cnts)
        self.cnts = np.array(self.cls_cnts).astype(float)
        
        # Get per-class image indexes
        self.cls_idxs = [[] for _ in range(self.num_classes)]
        for i, label in enumerate(self.dataset.labels):
            self.cls_idxs[label].append(i)
        for ci in range(self.num_classes):
            self.cls_idxs[ci] = np.array(self.cls_idxs[ci])
        
        # Build balanced weights based on class counts 
        self.balanced_weights = self.get_balanced_weights(self.nroot)
        self.manual_weights = self.get_manual_weights(self.lams[0])

        # Setup priority tree
        if self.ptype == 'score':
            self.init_weight = 1.
        elif self.ptype in ['CE', 'entropy']:
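            # 6.9 is roughly ln(1000), presumably the CE of a uniform prediction over 1000 classes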
            self.init_weight = 6.9
        else:
            raise NotImplementedError('ptype {} not implemented'.format(self.ptype))
        if self.manual_only:
            self.init_weight = 0.
        self.init_weight = np.power(self.init_weight, self.alpha)
        self.ptree = PriorityTree(self.num_samples, self.manual_weights,
                                  fixed_scale=self.fixed_scale,
                                  init_weight=self.init_weight)
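

Below is a minimal usage sketch, not taken from this repository. It assumes the class can be used as a standard torch.utils.data sampler and that the dataset only needs to expose a `labels` attribute (as the code above suggests); ToyDataset and the DataLoader wiring are illustrative assumptions.

# Minimal usage sketch (assumptions: MixedPrioritizedSampler is accepted as a DataLoader
# sampler; the dataset only needs a `labels` attribute, per __init__ above).
import numpy as np
from torch.utils.data import DataLoader, Dataset
from datasets.MixedPrioritizedSampler import MixedPrioritizedSampler

class ToyDataset(Dataset):
    """Tiny imbalanced dataset; ToyDataset itself is a hypothetical stand-in."""
    def __init__(self):
        self.labels = [0] * 90 + [1] * 10                  # 90 vs. 10 samples per class
        self.data = np.random.randn(100, 8).astype(np.float32)
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

dataset = ToyDataset()
sampler = MixedPrioritizedSampler(dataset, epochs=90, cycle=0, ptype='CE', alpha=1.0)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)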