def main()

in evaluate.py [0:0]


def main():
    """Evaluate a prediction file against the MSLS validation split.

    Parses the command-line options, builds the ``MSLS`` dataset in ``'val'``
    mode for the requested cities/task/subtask, validates the prediction file
    against the dataset keys, and prints ``recall@k`` and ``map@k`` for
    ``k`` in ``[1, 5, 10, 20]``.  When ``--output`` is given, the same lines
    are appended to that file.

    The prediction file is read with ``np.loadtxt(..., dtype=str)``: column 0
    of every row is a query key, the remaining columns are the ranked
    database keys predicted for it.

    Raises:
        FileNotFoundError: if a non-default ``--msls-root`` does not exist.
        AssertionError: if an evaluated query has no prediction line, has
            duplicate predictions, or if a query key appears on several lines.
    """
    parser = argparse.ArgumentParser()
    root_default = Path(__file__).parent / 'MSLS_sample'
    parser.add_argument('--prediction',
                        type=Path,
                        default=Path(__file__).parent / 'files' / 'example_msls_im2im_prediction.csv',
                        help='Path to the prediction to be evaluated')
    parser.add_argument('--msls-root',
                        type=Path,
                        default=root_default,
                        help='Path to MSLS containing the train_val and/or test directories')
    parser.add_argument('--threshold',
                        type=float,
                        default=25,
                        help='Positive distance threshold defining ground truth pairs')
    parser.add_argument('--cities',
                        type=str,
                        default='zurich',
                        help='Comma-separated list of cities to evaluate on.'
                             ' Leave blank to use the default validation set (sf,cph)')
    parser.add_argument('--task',
                        type=str,
                        default='im2im',
                        help='Task to evaluate on: '
                             '[im2im, seq2im, im2seq, seq2seq]')
    parser.add_argument('--seq-length',
                        type=int,
                        default=3,
                        help='Sequence length to evaluate on for seq2X and X2seq tasks')
    parser.add_argument('--subtask',
                        type=str,
                        default='all',
                        help='Subtask to evaluate on: '
                             '[all, s2w, w2s, o2n, n2o, d2n, n2d]')
    parser.add_argument('--output',
                        type=Path,
                        default=None,
                        help='Path to dump the metrics to')
    args = parser.parse_args()

    if not args.msls_root.exists():
        # Only the default sample location is fetched automatically; a
        # user-supplied root that is missing is treated as an error.
        if args.msls_root != root_default:
            raise FileNotFoundError("Not found: {}".format(args.msls_root))
        download_msls_sample(args.msls_root)

    # select for which ks to evaluate
    ks = [1, 5, 10, 20]
    if args.task == 'im2im' and args.seq_length > 1:
        print(f"Ignoring sequence length {args.seq_length} for the im2im task. (Setting to 1)")
        args.seq_length = 1

    dataset = MSLS(args.msls_root, cities=args.cities, mode='val', posDistThr=args.threshold,
                   task=args.task, seq_length=args.seq_length, subtask=args.subtask)

    def to_key(paths):
        # A dataset entry is a comma-separated group of image paths.  Its key
        # is the comma-joined result of `bn(path)` with the last four
        # characters dropped (presumably basename minus a ".jpg"-style
        # extension -- `bn` is defined outside this block, TODO confirm).
        return ','.join(bn(path)[:-4] for path in paths.split(','))

    # get query and positive image keys
    database_keys = [to_key(p) for p in dataset.dbImages]
    positive_keys = [[to_key(p) for p in dataset.dbImages[pos]] for pos in dataset.pIdx]
    query_keys = [to_key(p) for p in dataset.qImages[dataset.qIdx]]
    all_query_keys = [to_key(p) for p in dataset.qImages]

    # Sets give O(1) membership tests in the validation loops below.
    query_key_set = set(query_keys)
    all_query_key_set = set(all_query_keys)

    # create dummy predictions
    if not args.prediction.exists():
        create_dummy_predictions(args.prediction, dataset)

    # load prediction rankings
    predictions = np.loadtxt(args.prediction, ndmin=2, dtype=str)

    # Ensure that there is a prediction for each query image
    predicted_queries = set(predictions[:, 0].tolist())
    for key in query_keys:
        assert key in predicted_queries, "You didn't provide any predictions for image {}".format(key)

    # Ensure that all predictions are in database
    database_key_array = np.asarray(database_keys, dtype=str)
    for i, ranked in enumerate(predictions[:, 1:]):
        missing = np.isin(ranked, database_key_array, invert=True)
        # `any`, not `all`: a single unknown key is enough to warn and to
        # push the unknown keys behind the valid ones.
        if missing.any():
            print("Some of your predictions are not in the database for the selected task {}".format(ranked[missing]))
            print("This is probably because they are panorama images. They will be ignored in evaluation")

            # move missing elements to the last positions of prediction
            # (boolean indexing copies, so reading `ranked` while assigning
            # back into the same row is safe)
            predictions[i, 1:] = np.concatenate([ranked[~missing], ranked[missing]])

    # Ensure that all predictions are unique.  Only rows that will actually be
    # evaluated are checked, and the reported key is the one found on that
    # line rather than a positional guess from `query_keys`.
    for line_no, row in enumerate(predictions):
        if str(row[0]) not in query_key_set:
            continue
        assert len(row[1:]) == len(np.unique(row[1:])), \
            "You have duplicate predictions for image {} at line {}".format(row[0], line_no)

    # Ensure that all query images are unique
    assert len(predictions[:, 0]) == len(np.unique(predictions[:, 0])), "You have duplicate query images"

    # Check if there are predictions that don't correspond to any query images
    for i, key in enumerate(predictions[:, 0].tolist()):
        if key in query_key_set:
            continue
        if key in dataset.query_keys_with_no_match:
            # Known query without a positive match in the database:
            # dropped silently.
            continue
        if key in all_query_key_set:
            # TODO keep track of these and only produce the appropriate error message
            print(f"Ignoring predictions for {key}. It is not part of the query keys. "
                  f"Only keys in subtask_index.csv are used to evaluate.")
        else:
            print(f"Ignoring predictions for {key} at line {i}. It is not in the selected cities or is a panorama")
    predictions = np.array([row for row in predictions if str(row[0]) in query_key_set])

    # evaluate ranks
    # NOTE: `eval` here is called with a `ks=` keyword, so it must be a
    # project-level evaluation function that shadows the builtin (it is
    # defined/imported outside this block).
    metrics = eval(query_keys, positive_keys, predictions, ks=ks)

    # Format every metric first so that a missing metric fails before any
    # file is opened.
    lines = [
        '{}_{}@{}: {:.3f}'.format(args.subtask, metric, k, metrics['{}@{}'.format(metric, k)])
        for metric in ['recall', 'map']
        for k in ks
    ]
    for line in lines:
        print(line)

    # save metrics (append mode, so repeated runs accumulate in one file)
    if args.output:
        with open(args.output, 'a') as f:
            f.write('\n'.join(lines) + '\n')