in datasets/ClassPrioritySampler.py [0:0]
def reset_weights(self, epoch):
    """Refresh the sampling weights for ``epoch``.

    Two schedules are applied, in this order:

    1. Linear shifting (skipped when ``self.freeze`` is truthy): the manual
       weights are recomputed via ``self.get_manual_weights`` from the entry
       of ``self.lams`` at ``epoch`` clipped to ``[0, self.epochs - 1]``.
       They are stored in ``self.manual_weights``. They are then pushed to
       ``self.ptree`` (with ``self.rescale``) when ``self.fixed_scale > 0``,
       and to the backend distribution when ``self.manual_as_backend``.
    2. Root decay: this runs only when ``self.root_decay`` is ``'exp'``,
       ``'linear'`` or ``'autoexp'`` and ``epoch`` is a multiple of
       ``self.decay_gap``. That includes epoch 0. ``self.nroot`` is updated,
       and the balanced weights derived from it go to the backend
       distribution when ``self.manual_as_backend``, otherwise to
       ``self.ptree``.

    Args:
        epoch: current epoch index. It is used (clipped) to index
            ``self.lams`` and (unclipped) to gate and compute the root decay.
    """
    if not self.freeze:
        # Clamp so epochs outside the schedule reuse the first/last entry.
        lam_index = np.clip(epoch, 0, self.epochs - 1)
        manual = self.get_manual_weights(self.lams[lam_index])
        self.manual_weights = manual
        # NOTE: 'self.fixed_scale > 0' and 'self.manual_as_backend = True' are
        # expected to be mutually exclusive; nothing here enforces that.
        if self.fixed_scale > 0:
            self.ptree.reset_fixed_weights(manual, self.rescale)
        if self.manual_as_backend:
            self.update_backend_distribution(manual)

    # Root decay is only active for these modes. The mode test comes first,
    # so decay_gap is never consulted otherwise.
    if self.root_decay not in ('exp', 'linear', 'autoexp'):
        return
    if epoch % self.decay_gap != 0:
        return

    mode = self.root_decay
    if mode == 'exp':
        self.nroot *= 2
    elif mode == 'linear':
        self.nroot += 1
    else:
        # 'autoexp': recompute from scratch as decay_factor ** epoch
        # instead of compounding the previous value.
        self.nroot = np.power(self.decay_factor, epoch)

    balanced = self.get_balanced_weights(self.nroot)
    sink = (self.update_backend_distribution if self.manual_as_backend
            else self.ptree.reset_fixed_weights)
    sink(balanced)