in Dassl.pytorch/dassl/data/samplers.py [0:0]
def __iter__(self):
batch_idxs_dict = defaultdict(list)
for label in self.labels:
idxs = copy.deepcopy(self.index_dic[label])
if len(idxs) < self.n_ins:
idxs = np.random.choice(idxs, size=self.n_ins, replace=True)
random.shuffle(idxs)
batch_idxs = []
for idx in idxs:
batch_idxs.append(idx)
if len(batch_idxs) == self.n_ins:
batch_idxs_dict[label].append(batch_idxs)
batch_idxs = []
avai_labels = copy.deepcopy(self.labels)
final_idxs = []
while len(avai_labels) >= self.ncls_per_batch:
selected_labels = random.sample(avai_labels, self.ncls_per_batch)
for label in selected_labels:
batch_idxs = batch_idxs_dict[label].pop(0)
final_idxs.extend(batch_idxs)
if len(batch_idxs_dict[label]) == 0:
avai_labels.remove(label)
return iter(final_idxs)