in main.py [0:0]
def test(model, args):
    """Evaluate ``model`` on ``args.dataset_test`` and pickle the results.

    Which evaluations run depends on the test dataset:

    * ``datasets.Madonna`` / ``datasets.EVVE``: ``eval.map`` is called and
      its result is stored under the ``"map"`` key.
    * datasets whose ``is_localization`` attribute is truthy, except
      ``datasets.VCDB``: ``eval.localization_errors`` is called and its
      result is stored under ``"loc_errors"``.
    * ``datasets.VCDB``: ``eval.segment_pr`` is called, its
      ``[probas, labels]`` pair is stored under ``"precision_recall"``,
      and the PR-curve AUC and maximum F1 are printed.

    The collected results are then written with pickle to
    ``<args.output_dir>/tests_<dataset class name>.pkl``.

    Args:
        model: Model object forwarded unchanged to the ``eval`` helpers.
        args: Namespace-like object; this function reads ``dataset_test``,
            ``query_expansion`` and ``output_dir`` from it and also passes
            the whole object on to the ``eval`` helpers.

    Returns:
        None. Results are only printed and written to disk.
    """
    # NOTE(review): nesting assumes the pickle dump at the end runs for every
    # dataset, not only for VCDB -- confirm against the original layout.
    dataset = args.dataset_test
    results = defaultdict(list)

    if dataset in (datasets.Madonna, datasets.EVVE):
        map_result = eval.map(
            model,
            dataset,
            args,
            "all",
            query_expansion=args.query_expansion,
        )
        results["map"].append(map_result)

    if dataset.is_localization and dataset != datasets.VCDB:
        loc_result = eval.localization_errors(model, dataset, args, "all")
        results["loc_errors"].append(loc_result)

    if dataset == datasets.VCDB:
        scores, targets = eval.segment_pr(model, dataset, args, "all")
        results["precision_recall"].append([scores, targets])

        precision, recall, _ = precision_recall_curve(targets, scores)
        print("AUC test", auc(recall, precision))

        # The small constant (10e-8 == 1e-7) keeps the denominator non-zero
        # when precision and recall are both 0.
        f1_curve = 2 * (precision * recall) / (precision + recall + 10e-8)
        max_f1 = np.max(f1_curve)
        print("MaxF1 test", max_f1)

    pickle_name = "tests_%s.pkl" % dataset.__name__
    output_path = os.path.join(args.output_dir, pickle_name)
    with open(output_path, "wb") as pfile:
        pkl.dump(results, pfile, pkl.HIGHEST_PROTOCOL)