def reset_weights()

in datasets/MixedPrioritizedSampler.py [0:0]


    def reset_weights(self, epoch):
        """Refresh the fixed weights held by ``self.ptree`` for ``epoch``.

        Two independent updates are applied, in this order:

        * Manual weights -- only when ``self.freeze`` is falsy and
          ``self.fixed_scale > 0``. ``epoch`` is clamped to an index into
          ``self.lams``: ``self.epochs - 1`` at or past ``self.epochs``,
          ``0`` below 1, otherwise ``epoch`` itself. ``self.manual_weights``
          is rebuilt from that entry and pushed to ``self.ptree`` together
          with ``self.rescale``.
        * Root decay -- only when ``self.root_decay`` is ``'exp'``,
          ``'linear'`` or ``'autoexp'`` and ``epoch % self.decay_gap == 0``.
          ``self.nroot`` is updated, and the balanced weights computed from
          it are pushed to ``self.ptree``.

        Args:
            epoch: The current epoch number.
        """
        # ``fixed_scale`` is only consulted when the sampler is not frozen.
        update_manual = (not self.freeze) and self.fixed_scale > 0
        if update_manual:
            # Clamp the epoch into an index usable with ``self.lams``.
            idx = (self.epochs - 1 if epoch >= self.epochs
                   else 0 if epoch < 1
                   else epoch)
            self.manual_weights = self.get_manual_weights(self.lams[idx])
            self.ptree.reset_fixed_weights(self.manual_weights, self.rescale)

        mode = self.root_decay
        # ``decay_gap`` is only consulted for a recognised decay mode.
        if mode not in ('exp', 'linear', 'autoexp') or epoch % self.decay_gap != 0:
            return

        if mode == 'exp':
            self.nroot *= 2
        elif mode == 'linear':
            self.nroot += 1
        else:
            # 'autoexp': nroot is set to decay_factor ** epoch (an absolute
            # value derived from the epoch, not a cumulative multiplication).
            self.nroot = np.power(self.decay_factor, epoch)

        self.ptree.reset_fixed_weights(self.get_balanced_weights(self.nroot))