in lib/metrics.py [0:0]
def evaluate(net, xq, xb, gt, quantizers, best_key, device=None,
             trainset=None, ranks=(1, 10, 100)):
    """Measure 1-recall of ``net``-transformed vectors under each quantizer.

    Query, database and (optional) training vectors are passed through
    ``net`` with ``forward_pass``. For every name in ``quantizers``, a
    quantizer is built with ``getQuantizer``, trained on the transformed
    training set and used to encode the database. Symmetric quantizers also
    encode the queries. The nearest neighbors of the queries are then
    searched among the encoded database vectors, and 1-recall is reported at
    each rank in ``ranks``. The defaults 1, 10, 100 are comparable with
    fig 5, left of the paper.

    Args:
        net: model to evaluate; it is switched to eval mode and left there.
        xq: query vectors.
        xb: database vectors.
        gt: ground-truth neighbor ids; only its first column is compared
            against the retrieved ids.
        quantizers: iterable of quantizer names understood by
            ``getQuantizer``.
        best_key: key of the form ``'<quantizer>,rank=<rank>'`` selecting
            which recall is returned as ``score``.
        device: device passed to ``forward_pass`` and
            ``get_nearestneighbors``. Defaults to the device type of the
            first parameter of ``net``.
        trainset: optional training vectors for the quantizers. When None,
            None is passed to ``qt.train`` unchanged.
        ranks: ranks at which 1-recall is measured. The neighbor search
            retrieves ``max(ranks)`` ids per query.

    Returns:
        ``(res, score)``: ``res`` maps each quantizer name to its list of
        recalls (one per entry of ``ranks``, same order). ``score`` is the
        recall whose key equals ``best_key``, or 0 if no key matched.

    Raises:
        ValueError: if ``ranks`` is empty.
    """
    ranks = tuple(ranks)
    if not ranks:
        raise ValueError("ranks must contain at least one rank")
    # Retrieve enough neighbors to cover the largest requested rank. With the
    # default ranks this is 100, matching the previous hard-coded value.
    k = max(ranks)

    net.eval()
    if device is None:
        device = next(net.parameters()).device.type
    xqt = forward_pass(net, sanitize(xq), device=device)
    xbt = forward_pass(net, sanitize(xb), device=device)
    if trainset is not None:
        trainset = forward_pass(net, sanitize(trainset), device=device)
    nq, d = xqt.shape
    res = {}
    score = 0
    for quantizer in quantizers:
        qt = getQuantizer(quantizer, d)
        qt.train(trainset)
        xbtq = qt(xbt)
        # Symmetric quantizers search with encoded queries. Asymmetric ones
        # search with the unencoded (net-transformed) queries. Only the
        # chosen branch is evaluated.
        queries = xqt if qt.asymmetric else qt(xqt)
        I = get_nearestneighbors(queries, xbtq, k, device)
        print("%s\t nbit=%3d: " % (quantizer, qt.bits), end=' ')

        # 1-recall@rank: fraction of queries whose ground-truth id gt[:, 0]
        # (presumably the exact nearest neighbor) appears among the first
        # `rank` retrieved ids.
        recalls = []
        for rank in ranks:
            recall = (I[:, :rank] == gt[:, :1]).sum() / float(nq)
            key = '%s,rank=%d' % (quantizer, rank)
            if key == best_key:
                score = recall
            recalls.append(recall)
            print('%.4f' % recall, end=" ")
        res[quantizer] = recalls
        print("")
    return res, score