in src/MaskRCNNDetection/maskrcnn_benchmark/data/samplers/grouped_batch_sampler.py [0:0]
def _prepare_batches(self):
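    """
    Organize the indices produced by self.sampler into batches of
    size self.batch_size whose elements all share the same group id,
    following the sampler's ordering as closely as possible.
    Returns a list of lists of dataset indices.
    """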
    dataset_size = len(self.group_ids)
    # get the sampled indices from the sampler
    sampled_ids = torch.as_tensor(list(self.sampler))
    # potentially not all elements of the dataset were sampled
    # by the sampler (e.g., DistributedSampler).
    # construct a tensor which contains -1 if the element was
    # not sampled, and a non-negative number indicating the
    # position at which the element was sampled.
    # for example, if sampled_ids = [3, 1] and dataset_size = 5,
    # the order is [-1, 1, -1, 0, -1]
    order = torch.full((dataset_size,), -1, dtype=torch.int64)
    order[sampled_ids] = torch.arange(len(sampled_ids))
    # get a mask with the elements that were sampled
    mask = order >= 0
    # find the elements that belong to each individual cluster
    clusters = [(self.group_ids == i) & mask for i in self.groups]
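    # (hypothetical running example, not from the original file: with
    # sampled_ids = [3, 1, 4, 0], group_ids = [0, 0, 1, 1, 0] and
    # groups = [0, 1], we get order = [3, 1, -1, 0, 2],
    # mask = [T, T, F, T, T], and clusters =
    # [[T, T, F, F, T], [F, F, F, T, F]])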
    # get relative order of the elements inside each cluster
    # that follows the order from the sampler
    relative_order = [order[cluster] for cluster in clusters]
    # with the relative order, find the absolute order in the
    # sampled space
    permutation_ids = [s[s.sort()[1]] for s in relative_order]
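    # note that s[s.sort()[1]] equals s.sort()[0], i.e. each cluster's
    # sampler positions sorted ascending.
    # (running example: relative_order = [[3, 1, 2], [0]], so
    # permutation_ids = [[1, 2, 3], [0]])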
    # permute each cluster so that they follow the order from
    # the sampler
    permuted_clusters = [sampled_ids[idx] for idx in permutation_ids]
    # split each cluster into chunks of batch_size, and merge them
    # into a single flat list of tensors
    splits = [c.split(self.batch_size) for c in permuted_clusters]
    merged = tuple(itertools.chain.from_iterable(splits))
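    # (running example, with batch_size = 2: permuted_clusters =
    # [[1, 4, 0], [3]], splits = ([1, 4], [0]) and ([3],), so
    # merged = ([1, 4], [0], [3]))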
    # now each batch internally has the right order, but the batches
    # themselves are still grouped by cluster. Find the permutation
    # between batches that brings them as close as possible to the
    # order from the sampler. For that, take the ordering as given by
    # the first element of each batch, and sort accordingly
    first_element_of_batch = [t[0].item() for t in merged]
    # get an inverse mapping from sampled indices to the position
    # where they occur (as returned by the sampler)
    inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())}
    # from the first element in each batch, get a relative ordering
    first_index_of_batch = torch.as_tensor(
        [inv_sampled_ids_map[s] for s in first_element_of_batch]
    )
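    # (running example: first_element_of_batch = [1, 0, 3] and
    # inv_sampled_ids_map = {3: 0, 1: 1, 4: 2, 0: 3}, so
    # first_index_of_batch = [1, 3, 0])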
    # permute the batches so that they approximately follow the order
    # from the sampler
    permutation_order = first_index_of_batch.sort(0)[1].tolist()
    # finally, permute the batches
    batches = [merged[i].tolist() for i in permutation_order]
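    # (running example: permutation_order = [2, 0, 1], so
    # batches = [[3], [1, 4], [0]]; with drop_uneven=True only
    # [1, 4] would survive)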
    if self.drop_uneven:
        # keep only batches with exactly batch_size elements; each
        # cluster may contribute one smaller tail batch, which is dropped
        kept = []
        for batch in batches:
            if len(batch) == self.batch_size:
                kept.append(batch)
        batches = kept
    return batches
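
# A minimal usage sketch, assuming this method belongs to maskrcnn_benchmark's
# GroupedBatchSampler(sampler, group_ids, batch_size, drop_uneven=False) and
# that torch and itertools are imported at the top of the file; the toy
# group_ids and SequentialSampler below are illustrative, not part of the repo.
if __name__ == "__main__":
    import torch
    from torch.utils.data.sampler import SequentialSampler

    group_ids = torch.tensor([0, 0, 1, 1, 0])  # two groups over 5 elements
    sampler = SequentialSampler(range(5))      # yields 0, 1, 2, 3, 4
    batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size=2)
    # each batch contains indices from a single group; here this prints
    # [0, 1], then [2, 3], then [4]
    for batch in batch_sampler:
        print(batch)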