in datasets/ClassPrioritySampler.py [0:0]
def update_weights(self, inds, weights, labels):
    """Fold a batch of new per-example weights into the priority state.

    Does nothing unless ``self.manual_only`` is falsy and
    ``self.pri_mode == 'train'``.

    Args:
        inds: Example indices for the batch; used to index the
            ``self.per_example_*`` arrays. Must support boolean-mask
            indexing (presumably a NumPy array -- confirm with callers).
        weights: New weights, aligned element-wise with ``inds``. They are
            clipped to ``[0, self.init_weight]`` before use; the caller's
            array is not modified.
        labels: Class label of each entry, aligned element-wise with
            ``inds``. Each distinct label is handled separately and is
            used as the key passed to ``self.ptree``.

    Returns:
        None. Updates ``self.per_example_velocities``,
        ``self.per_example_uni_weights`` and ``self.ptree`` in place.
    """
    if self.manual_only or self.pri_mode != 'train':
        return

    clipped = np.clip(weights, 0, self.init_weight)

    # Handle the batch one class at a time.
    for cls in np.unique(labels):
        in_class = labels == cls
        idx = inds[in_class]

        # Momentum-smoothed step from the stored weight toward the new one.
        previous = self.per_example_uni_weights[idx]
        step = clipped[in_class] - previous
        step = self.momentum * self.per_example_velocities[idx] + \
            (1 - self.momentum) * step

        # Remember the step as the new velocity and apply it to the weights.
        # NOTE(review): if `inds` repeats an index, these fancy-indexed
        # writes land once per index while `step.sum()` below counts every
        # occurrence -- confirm batches never contain duplicates.
        self.per_example_velocities[idx] = step
        self.per_example_uni_weights[idx] += step

        # Scale the step by the per-example ratios. The step is a plain
        # difference, so this equals scaling the weights both before and
        # after the update.
        step *= self.per_example_ratios[idx]

        # Push the change for this class into the tree.
        if self.alpha != 1:
            # Replace the class entry with the sum over all of its examples.
            # NOTE(review): this sum applies neither per_example_ratios nor
            # alpha -- confirm that is intended.
            self.ptree.update(cls, self.per_example_uni_weights[self.cls_idxs[cls]].sum())
        else:
            self.ptree.update_delta(cls, step.sum())