in datasets/MixedPrioritizedSampler.py [0:0]
def __init__(self, dataset, balance_scale=1.0, fixed_scale=1.0,
             lam=None, epochs=90, cycle=0, nroot=None, manual_only=False,
             rescale=False, root_decay=None, decay_gap=30, ptype='score',
             alpha=1.0):
    """Set up class statistics, the lambda schedule and the priority tree.

    Args:
        dataset: Object supporting ``len()`` and exposing a ``labels``
            sequence. Labels are used directly as list indices, so they must
            be exactly the integers ``0 .. num_classes - 1``.
        balance_scale: Stored as ``self.balance_scale`` (not used here).
        fixed_scale: Stored and forwarded to ``PriorityTree``.
        lam: Fixed mixing coefficient. When None, a per-epoch schedule
            ``self.lams`` is built from ``cycle`` and ``self.freeze`` is False;
            otherwise ``self.lams == [lam]`` and ``self.freeze`` is True.
        epochs: Length of the lambda schedule for ``cycle=0``; also sets the
            'autoexp' decay factor.
        cycle: Schedule shape when ``lam`` is None. 0 is one 0->1 ramp over
            ``epochs``. 1 is three 0->1 ramps of ``epochs // 3`` each. 2 is
            ramp up, ramp down, ramp up.
        nroot: Passed (as ``self.nroot``) to ``get_balanced_weights``. Without
            ``root_decay`` it must be None or >= 2. With ``root_decay`` it is
            reset to 1, and for 'autoexp' the caller's value sets
            ``self.decay_factor = nroot ** (1 / (epochs - 1))``.
        manual_only: If True (also forced by ``root_decay``), the base initial
            priority weight is 0 before the ``alpha`` exponent is applied.
        rescale: Stored as ``self.rescale`` (not used here).
        root_decay: None, 'exp', 'linear' or 'autoexp'. Any non-None value
            forces ``lam=1``, ``manual_only=True`` and ``nroot=1``.
        decay_gap: Stored; forced to 1 for 'autoexp'.
        ptype: 'score' (base initial weight 1.0) or 'CE' / 'entropy'
            (base initial weight 6.9).
        alpha: Exponent applied to the base initial priority weight.

    Raises:
        ValueError: 'autoexp' without ``nroot`` or with ``epochs < 2``, or
            labels that are not exactly ``0 .. num_classes - 1``.
        NotImplementedError: Unknown ``cycle`` (when ``lam`` is None) or
            unknown ``ptype``.
        AssertionError: Unknown ``root_decay``, or ``nroot < 2`` when no
            ``root_decay`` is used.
    """
    self.dataset = dataset
    self.balance_scale = balance_scale
    self.fixed_scale = fixed_scale
    self.epochs = epochs
    self.lam = lam
    self.cycle = cycle
    self.nroot = nroot
    self.rescale = rescale
    self.manual_only = manual_only
    self.root_decay = root_decay
    self.decay_gap = decay_gap
    self.ptype = ptype
    self.num_samples = len(dataset)
    self.alpha = alpha

    # If using root_decay, reset the relevant parameters.
    if self.root_decay in ['exp', 'linear', 'autoexp']:
        self.lam = 1
        self.manual_only = True
        self.nroot = 1
        if self.root_decay == 'autoexp':
            # The decay factor uses the caller-supplied nroot (self.nroot was
            # just reset to 1) and divides by (epochs - 1). Fail early with a
            # clear message instead of a numpy TypeError / ZeroDivisionError.
            if nroot is None:
                raise ValueError(
                    "root_decay='autoexp' requires nroot to be set")
            if self.epochs < 2:
                raise ValueError(
                    "root_decay='autoexp' requires epochs >= 2, got {}"
                    .format(self.epochs))
            self.decay_gap = 1
            self.decay_factor = np.power(nroot, 1 / (self.epochs - 1))
    else:
        assert self.root_decay is None
        assert self.nroot is None or self.nroot >= 2
    print("====> Decay GAP: {}".format(self.decay_gap))

    # Take care of lambdas: either a fixed value or a per-epoch schedule.
    if self.lam is None:
        self.freeze = False
        if cycle == 0:
            self.lams = np.linspace(0, 1, epochs)
        elif cycle in (1, 2):
            # NOTE(review): this schedule has 3 * (epochs // 3) entries, which
            # is shorter than `epochs` when epochs is not a multiple of 3 -
            # confirm callers never index past the end.
            ramp = np.linspace(0, 1, epochs // 3)
            middle = ramp if cycle == 1 else ramp[::-1]
            self.lams = np.concatenate([ramp, middle, ramp])
        else:
            raise NotImplementedError(
                'cycle = {} not implemented'.format(cycle))
    else:
        self.lams = [self.lam]
        self.freeze = True

    # Get the number of samples per class (ordered by sorted label value).
    self.labels = labels = np.array(self.dataset.labels)
    uniq, counts = np.unique(labels, return_counts=True)
    self.cls_cnts = list(counts)
    self.num_classes = len(self.cls_cnts)
    # Labels index self.cls_idxs directly, so they must be exactly
    # 0 .. num_classes - 1. Otherwise the counts above (ordered by sorted
    # label) would not line up with the per-class index lists. A negative
    # label would also be silently filed under the last class through
    # Python's negative indexing.
    if self.num_classes and (uniq[0] != 0
                             or uniq[-1] != self.num_classes - 1):
        raise ValueError(
            'labels must be the consecutive integers 0..{}; got min {} '
            'and max {}'.format(self.num_classes - 1, uniq[0], uniq[-1]))
    self.cnts = np.array(self.cls_cnts).astype(float)

    # Get per-class sample indexes (ascending sample order within a class).
    self.cls_idxs = [[] for _ in range(self.num_classes)]
    for i, label in enumerate(labels):
        self.cls_idxs[label].append(i)
    self.cls_idxs = [np.array(idxs) for idxs in self.cls_idxs]

    # Build balanced weights based on class counts.
    self.balanced_weights = self.get_balanced_weights(self.nroot)
    self.manual_weights = self.get_manual_weights(self.lams[0])

    # Set up the priority tree.
    if self.ptype == 'score':
        self.init_weight = 1.
    elif self.ptype in ['CE', 'entropy']:
        # NOTE(review): 6.9 is presumably ~ln(1000), i.e. the maximum loss
        # for a uniform prediction over 1000 classes - confirm.
        self.init_weight = 6.9
    else:
        raise NotImplementedError('ptype {} not implemented'.format(self.ptype))
    if self.manual_only:
        self.init_weight = 0.
    self.init_weight = np.power(self.init_weight, self.alpha)
    self.ptree = PriorityTree(self.num_samples, self.manual_weights,
                              fixed_scale=self.fixed_scale,
                              init_weight=self.init_weight)