def update_weights()

in datasets/ClassPrioritySampler.py [0:0]


    def update_weights(self, inds, weights, labels):
        """Update the per-example priority weights and the class priority tree.

        Does nothing unless ``self.manual_only`` is false and
        ``self.pri_mode == 'train'``.

        Args:
            inds: Dataset indices of the examples in the batch.
            weights: New priority value for each example in ``inds``; values
                are clipped to ``[0, self.init_weight]`` before use.
            labels: Class label of each example in ``inds``.

        Each example's stored weight moves toward its new value through a
        momentum-smoothed delta (``self.momentum``), and the class leaf of
        ``self.ptree`` is updated to match. If the same example index appears
        several times in the batch, only its last occurrence is used.
        """
        if self.manual_only or self.pri_mode != 'train':
            return

        # Accept array-likes (lists, CPU tensors); a no-op for ndarrays.
        inds = np.asarray(inds)
        labels = np.asarray(labels)
        weights = np.clip(weights, 0, self.init_weight)

        # Iterate over all classes present in the batch
        for cls in np.unique(labels):
            in_cls = labels == cls
            cls_inds = inds[in_cls]
            cls_weights = weights[in_cls]

            # A sampler drawing with replacement can repeat an example within
            # one batch. Fancy-indexed `+=` below would then apply only one
            # delta per example while `delta.sum()` counts all of them, so the
            # tree total would drift from the stored weights. Keep the last
            # occurrence of each index (in batch order) so both stay in sync.
            _, rev_first = np.unique(cls_inds[::-1], return_index=True)
            keep = np.sort(len(cls_inds) - 1 - rev_first)
            example_inds = cls_inds[keep]

            # Momentum-smoothed change of the per-example weights
            delta = cls_weights[keep] - self.per_example_uni_weights[example_inds]
            delta = (self.momentum * self.per_example_velocities[example_inds]
                     + (1 - self.momentum) * delta)

            # Store the new velocities and apply them to the weights.
            # Fancy-index assignment copies `delta`, so the in-place scaling
            # below does not alter the stored velocities.
            self.per_example_velocities[example_inds] = delta
            self.per_example_uni_weights[example_inds] += delta

            # Scale the delta by the per-example ratios, i.e. the change in
            # ratio-scaled per-example weight, before pushing it to the tree
            delta *= self.per_example_ratios[example_inds]

            # Update the class leaf in the priority tree
            if self.alpha == 1:
                self.ptree.update_delta(cls, delta.sum())
            else:
                # NOTE(review): unlike the alpha == 1 branch, this sets the
                # leaf to the raw sum of per-example weights without the
                # per-example ratios; confirm this is what ptree expects.
                self.ptree.update(
                    cls, self.per_example_uni_weights[self.cls_idxs[cls]].sum())