in datasets/MixedPrioritizedSampler.py [0:0]
def reset_weights(self, epoch):
    """Refresh the sampler's fixed weights for ``epoch``.

    Two independent updates may run, in this order:

    1. Manual-weight shifting. This runs only when ``self.freeze`` is falsy
       and ``self.fixed_scale`` is positive. ``epoch`` is mapped to an index
       into ``self.lams``: values ``>= self.epochs`` map to
       ``self.epochs - 1``, values ``< 1`` map to ``0``, and anything else is
       used as-is. That entry is turned into manual weights with
       ``self.get_manual_weights``. The result is stored on
       ``self.manual_weights`` and passed to
       ``self.ptree.reset_fixed_weights`` together with ``self.rescale``.
    2. Root decay. This runs only when ``self.root_decay`` is ``'exp'``,
       ``'linear'`` or ``'autoexp'`` and ``epoch`` is a multiple of
       ``self.decay_gap``. ``self.nroot`` is doubled, incremented by one, or
       set to ``self.decay_factor ** epoch``, respectively. Weights from
       ``self.get_balanced_weights(self.nroot)`` are then passed to
       ``self.ptree.reset_fixed_weights``.

    Args:
        epoch: the current epoch number.
    """
    # 1. Manual-weight shifting.
    if not self.freeze and self.fixed_scale > 0:
        if epoch >= self.epochs:
            # Past the end of the schedule: hold the final entry.
            idx = self.epochs - 1
        else:
            idx = 0 if epoch < 1 else epoch
        self.manual_weights = self.get_manual_weights(self.lams[idx])
        self.ptree.reset_fixed_weights(self.manual_weights, self.rescale)

    # 2. Root decay.
    # The membership test is evaluated first, so ``self.decay_gap`` is
    # only touched (and only needs to be non-zero) for a supported mode.
    decay = self.root_decay
    if decay not in ('exp', 'linear', 'autoexp') or epoch % self.decay_gap != 0:
        return

    if decay == 'exp':
        self.nroot *= 2
    elif decay == 'linear':
        self.nroot += 1
    elif decay == 'autoexp':
        # Computed directly from the epoch rather than accumulated from the
        # previous ``nroot``.
        self.nroot = np.power(self.decay_factor, epoch)

    balanced = self.get_balanced_weights(self.nroot)
    # NOTE(review): unlike step 1, ``self.rescale`` is not forwarded here.
    # When both steps run in the same call, these weights are pushed after
    # the manual ones. Confirm both behaviors against PriorityTree.
    self.ptree.reset_fixed_weights(balanced)