in mapillary_sls/datasets/msls.py [0:0]
def update_subcache(self, net = None):
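# Refresh the triplet cache for the next training subset. With no network we
# fall back to randomly sampled triplets below; with a network we embed the
# cached queries, their positives and a random negative pool and mine the
# hardest negatives for each query.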
# reset triplets
self.triplets = []
# if there is no network associated with the cache, then we don't do any hard negative mining.
# Instead we just create some naive triplets based on distance.
if net is None:
qidxs = np.random.choice(len(self.qIdx), self.cached_queries, replace = False)
for q in qidxs:
# get query idx
qidx = self.qIdx[q]
# get positives
pidxs = self.pIdx[q]
# choose a random positive (within positive range (default 10 m))
pidx = np.random.choice(pidxs, size = 1)[0]
# get negatives
while True:
nidxs = np.random.choice(len(self.dbImages), size = self.nNeg)
# ensure that none of the chosen negative images are within the negative range (default 25 m)
if sum(np.in1d(nidxs, self.nonNegIdx[q])) == 0:
break
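# a triplet is stored as [query, positive, negatives...]; the target marks the
# query with -1, the positive with 1 and each negative with 0
# (e.g. with nNeg = 5: target = [-1, 1, 0, 0, 0, 0, 0])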
# package the triplet and target
triplet = [qidx, pidx, *nidxs]
target = [-1, 1] + [0]*len(nidxs)
self.triplets.append((triplet, target))
# increment subset counter
self.current_subset += 1
return
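# otherwise, mine hard negatives: embed the current subset of cached queries,
# all of their positives and a random pool of negative candidates, then pick,
# per query, the negatives that violate the triplet margin the most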
# take n query images
qidxs = np.asarray(self.subcache_indices[self.current_subset])
# take their positives in the database
pidxs = np.unique([i for idx in self.pIdx[qidxs] for i in idx])
# take m = 5*cached_queries negative images as the candidate pool
nidxs = np.random.choice(len(self.dbImages), self.cached_negatives, replace=False)
# and make sure that none of them are non-negatives (within the negative range) of any cached query
nidxs = nidxs[np.in1d(nidxs, np.unique([i for idx in self.nonNegIdx[qidxs] for i in idx]), invert=True)]
# make dataloaders for query, positive and negative images
opt = {'batch_size': self.bs, 'shuffle': False, 'num_workers': self.threads, 'pin_memory': True}
qloader = torch.utils.data.DataLoader(ImagesFromList(self.qImages[qidxs], transform=self.transform),**opt)
ploader = torch.utils.data.DataLoader(ImagesFromList(self.dbImages[pidxs], transform=self.transform),**opt)
nloader = torch.utils.data.DataLoader(ImagesFromList(self.dbImages[nidxs], transform=self.transform),**opt)
# calculate their descriptors
net.eval()
with torch.no_grad():
# initialize descriptors
qvecs = torch.zeros(len(qidxs), net.meta['outputdim']).to(self.device)
pvecs = torch.zeros(len(pidxs), net.meta['outputdim']).to(self.device)
nvecs = torch.zeros(len(nidxs), net.meta['outputdim']).to(self.device)
bs = opt['batch_size']
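# descriptors are written back in batch order; the last batch may be smaller
# than bs, which is fine because the slice end index is clipped to the tensor size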
# compute descriptors
for i, batch in tqdm(enumerate(qloader), desc = 'compute query descriptors'):
X, y = batch
qvecs[i*bs:(i+1)*bs, : ] = net(X.to(self.device)).data
for i, batch in tqdm(enumerate(ploader), desc = 'compute positive descriptors'):
X, y = batch
pvecs[i*bs:(i+1)*bs, :] = net(X.to(self.device)).data
for i, batch in tqdm(enumerate(nloader), desc = 'compute negative descriptors'):
X, y = batch
nvecs[i*bs:(i+1)*bs, :] = net(X.to(self.device)).data
print('>> Searching for hard negatives...')
# compute dot product scores and ranks on GPU
pScores = torch.mm(qvecs, pvecs.t())
pScores, pRanks = torch.sort(pScores, dim=1, descending=True)
# compute similarity scores between queries and negatives
nScores = torch.mm(qvecs, nvecs.t())
nScores, nRanks = torch.sort(nScores, dim=1, descending=True)
# convert to cpu and numpy
pScores, pRanks = pScores.cpu().numpy(), pRanks.cpu().numpy()
nScores, nRanks = nScores.cpu().numpy(), nRanks.cpu().numpy()
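# three index domains are used below: the original dataset indices (qIdx / dbImages),
# the cache indices into pidxs / nidxs, and the rank positions in the sorted score
# matrices; the comments track which domain each variable lives in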
# selection of hard triplets
for q in range(len(qidxs)):
qidx = qidxs[q]
# find positive idx for this query (cache idx domain)
cached_pidx = np.where(np.in1d(pidxs, self.pIdx[qidx]))
# find idx of positive idx in rank matrix (descending cache idx domain)
pidx = np.where(np.in1d(pRanks[q,:], cached_pidx))
# take the closest positive
dPos = pScores[q, pidx][0][0]
# get similarity scores to all negatives
dNeg = nScores[q, :]
# how much is each negative violating the margin
loss = dPos - dNeg + self.margin ** 0.5
violatingNeg = 0 < loss
# if fewer than nNeg negatives are violating, skip this query
if np.sum(violatingNeg) <= self.nNeg: continue
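# rank the negatives by loss: with loss = dPos - dNeg + margin, the most similar
# (hardest) negatives get the smallest loss, so the first nNeg entries of the
# ascending argsort are the hardest ones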
# select hardest negatives
hardest_negIdx = np.argsort(loss)[:self.nNeg]
# map the selected rank positions back to the cache idx domain
cached_hardestNeg = nRanks[q, hardest_negIdx]
# select the closest positive (back to cache idx domain)
cached_pidx = pRanks[q, pidx][0][0]
# transform back to original index (back to original idx domain)
qidx = self.qIdx[qidx]
pidx = pidxs[cached_pidx]
hardestNeg = nidxs[cached_hardestNeg]
# package the triplet and target
triplet = [qidx, pidx, *hardestNeg]
target = [-1, 1] + [0]*len(hardestNeg)
self.triplets.append((triplet, target))
# increment subset counter
self.current_subset += 1