def reset_weights()

in datasets/ClassPrioritySampler.py [0:0]


    def reset_weights(self, epoch):
        """Refresh the sampler's weights for the epoch that is about to start.

        Two independent schedules are applied, in this order.

        1. Linear shifting, skipped when ``self.freeze`` is truthy.
           ``self.manual_weights`` is recomputed from ``self.lams`` indexed at
           ``epoch`` clamped to ``[0, self.epochs - 1]``. The result is then
           pushed to the priority tree (with ``self.rescale``) when
           ``self.fixed_scale > 0``. It is also pushed to the backend
           distribution when ``self.manual_as_backend`` is set.
        2. Root decay, applied only when ``self.root_decay`` is one of
           ``'exp'``, ``'linear'`` or ``'autoexp'`` and ``epoch`` is a
           multiple of ``self.decay_gap``. ``self.nroot`` is updated in place:
           it is doubled, incremented, or set to ``decay_factor ** epoch``.
           Balanced weights derived from it then replace either the backend
           distribution or the tree's fixed weights.

        Note: the ``'exp'`` and ``'linear'`` updates modify ``self.nroot`` in
        place, so calling this twice for the same epoch compounds them.

        Args:
            epoch: Index of the epoch about to start.
        """
        # --- Phase 1: linear shifting of the manual weights -----------------
        if not self.freeze:
            # Clamp so epochs outside the schedule reuse its first/last entry.
            clipped_epoch = np.clip(epoch, 0, self.epochs - 1)
            self.manual_weights = self.get_manual_weights(self.lams[clipped_epoch])
            # Callers are expected to keep 'fixed_scale > 0' and
            # 'manual_as_backend' mutually exclusive; nothing here enforces it,
            # so both checks are evaluated independently.
            if self.fixed_scale > 0:
                self.ptree.reset_fixed_weights(self.manual_weights, self.rescale)
            if self.manual_as_backend:
                self.update_backend_distribution(self.manual_weights)

        # --- Phase 2: root decay --------------------------------------------
        decay_mode = self.root_decay
        if decay_mode not in ('exp', 'linear', 'autoexp'):
            # decay_gap is deliberately not consulted when decay is disabled.
            return
        if epoch % self.decay_gap != 0:
            return

        if decay_mode == 'exp':
            self.nroot *= 2
        elif decay_mode == 'linear':
            self.nroot += 1
        elif decay_mode == 'autoexp':
            # Closed form instead of an incremental multiply by decay_factor.
            self.nroot = np.power(self.decay_factor, epoch)

        balanced = self.get_balanced_weights(self.nroot)
        if self.manual_as_backend:
            self.update_backend_distribution(balanced)
        else:
            # Called without a rescale argument, unlike the phase-1 call.
            self.ptree.reset_fixed_weights(balanced)