def eval_clears_mot()

in siammot/eval/eval_clears_mot.py [0:0]


def eval_clears_mot(samples, predicted_samples, data_filter_fn=None,
                    iou_thresh=0.5):
    """
    Accumulate per-frame ground-truth/prediction matches for every sample with
    motmetrics and render a summary of the CLEAR MOT / identity metrics.

    :param samples: a list of (sample_id, sample:DataSample)
    :param predicted_samples: a dict with (sample_id: predicted_tracks:DataSample)
    :param data_filter_fn: a callable function to filter entities. It is called as
                           ``data_filter_fn(gt_entities, meta_data=sample.metadata)``
                           for ground truth and as
                           ``data_filter_fn(predicted_entities, ignore_gt)`` for
                           predictions, and must return a ``(valid, ignored)`` pair
    :param iou_thresh: The IOU (between a predicted bounding box and gt ) threshold
                       that determines a predicted bounding box is a true positive
    :return: a tuple ``(all_accumulators, summary)`` where ``all_accumulators`` holds
             one ``mm.MOTAccumulator`` per sample (in input order) and ``summary`` is
             the rendered metrics table padded with two newlines on each side
             (only the padding when ``samples`` is empty)
    :raises AssertionError: if ``iou_thresh`` is not in the interval (0, 1]

    A sample_id missing from ``predicted_samples`` is not handled here; the lookup
    error (``KeyError`` for a dict) propagates to the caller.
    """

    # Validate before touching motmetrics so bad input fails fast.
    assert 0 < iou_thresh <= 1, "iou_thresh must be in the interval (0, 1]"

    def get_id_and_bbox(entities):
        # Split a list of entities into parallel lists of ids and boxes.
        ids = [entity.id for entity in entities]
        bboxes = [entity.bbox for entity in entities]
        return ids, bboxes

    # The metrics host is created up front (as before) and used after the loop
    # to compute and format the summary.
    metrics_host = mm.metrics.create()

    # The distance handed to the accumulator is thresholded at 1 - iou_thresh
    # (presumably motmetrics uses 1 - IoU as distance and invalidates pairs above
    # max_iou -- see the motmetrics documentation).
    max_iou_distance = 1 - iou_thresh

    all_accumulators = []
    sample_ids = []

    for (sample_id, sample) in tqdm(samples):

        predicted_tracks = predicted_samples[sample_id]
        # The ground-truth sample defines how many frames are evaluated.
        num_frames = len(sample)

        accumulator = mm.MOTAccumulator(auto_id=True)
        for i in range(num_frames):
            valid_gt = sample.get_entities_for_frame_num(i)
            ignore_gt = []

            # If data filter function is available
            if data_filter_fn is not None:
                valid_gt, ignore_gt = data_filter_fn(valid_gt,
                                                     meta_data=sample.metadata)
            gt_ids, gt_bboxes = get_id_and_bbox(valid_gt)

            # NOTE: frames without any ground-truth annotation (e.g. low-fps
            # annotation such as in CRP) are NOT skipped: their predictions are
            # still passed to the accumulator with an empty ground-truth list.
            predicted_entities = predicted_tracks.get_entities_for_frame_num(i)

            # If data filter function is available, the ignored ground truth of
            # this frame is forwarded so the filter can take it into account.
            if data_filter_fn is not None:
                valid_pred, _ = data_filter_fn(predicted_entities, ignore_gt)
            else:
                valid_pred = predicted_entities

            out_ids, out_bboxes = get_id_and_bbox(valid_pred)

            bbox_distances = mm.distances.iou_matrix(gt_bboxes, out_bboxes,
                                                     max_iou=max_iou_distance)
            accumulator.update(gt_ids, out_ids, bbox_distances)

        all_accumulators.append(accumulator)
        sample_ids.append(sample_id)

    # Make sure to update to the latest version of motmetrics via pip or idf1 calculation might be very slow
    metrics = ['num_frames', 'mostly_tracked', 'partially_tracked', 'mostly_lost', 'num_switches',
               'num_false_positives', 'num_misses', 'mota', 'motp', 'idf1']

    strsummary = ""
    # With no samples there is nothing to summarize; the summary stays empty.
    if all_accumulators:
        summary = metrics_host.compute_many(
            all_accumulators,
            metrics=metrics,
            names=sample_ids,
            generate_overall=True
        )

        strsummary = mm.io.render_summary(
            summary,
            formatters=metrics_host.formatters,
            namemap=mm.io.motchallenge_metric_names
        )

    return all_accumulators, "\n\n"+strsummary+"\n\n"