def eval_iou_return_GDFs()

in libs/solaris/eval/base.py [0:0]


    def eval_iou_return_GDFs(self, miniou=0.5, iou_field_prefix='iou_score',
                             ground_truth_class_field='',
                             calculate_class_scores=True,
                             class_list=['all']):
        """Evaluate IoU between the ground truth and proposals.

        Proposals in ``self.proposal_GDF`` are matched greedily, in row order,
        to the not-yet-matched ground truth polygon with which they have the
        highest IoU. A match is recorded only when that IoU is strictly
        greater than `miniou`; the matched ground truth polygon is then
        removed so it cannot be matched by another proposal.

        Side effects: for every scored class a column named
        ``"{iou_field_prefix}_{class_id}"`` is written to
        ``self.proposal_GDF`` (best matched IoU, or ``0`` when unmatched), and
        ``self.ground_truth_GDF_Edit`` is left holding the unmatched ground
        truth of the last scored class.

        Arguments
        ---------
        miniou : float, optional
            IoU threshold that a proposal's best IoU must strictly exceed to
            qualify as a successful object detection event. Defaults to
            ``0.5``.
        iou_field_prefix : str, optional
            Prefix of the per-class IoU score columns written to
            ``self.proposal_GDF``; the full column name is
            ``"{iou_field_prefix}_{class_id}"``. Defaults to ``"iou_score"``.
        ground_truth_class_field : str, optional
            The column in ``self.ground_truth_GDF`` that indicates the class of
            each polygon. Required if using ``calculate_class_scores`` or if
            `class_list` contains any entry other than ``'all'``.
        calculate_class_scores : bool, optional
            Should class-by-class scores be calculated? Defaults to ``True``.
        class_list : list, optional
            List of classes to be scored. Defaults to ``['all']`` (score all
            classes). When `calculate_class_scores` is ``True``, ``['all']``
            is expanded to the union of the ground truth classes and the
            proposals' ``'__max_conf_class'`` values, in unspecified order.
            The list passed in is never mutated.

        Returns
        -------
        scoring_dict_list : list
            list of score output dicts, one per scored class. The dicts
            contain the following keys: ::
                ('class_id', 'iou_field', 'TruePos', 'FalsePos', 'FalseNeg',
                'Precision', 'Recall', 'F1Score')
        True_Pos_gdf : gdf or None
            Proposals counted as true positives for the *last* class scored,
            or ``None`` if there were none.
        False_Neg_gdf : gdf or None
            Ground truth polygons left unmatched for the *last* class scored,
            or ``None`` if there were none.
        False_Pos_gdf : gdf or None
            Proposals counted as false positives for the *last* class scored,
            or ``None`` if there were none.

        Raises
        ------
        ValueError
            If `calculate_class_scores` is ``True`` but no
            `ground_truth_class_field` is given.
        """

        scoring_dict_list = []
        # Bound up front so the return is valid even when there are no
        # proposals or `class_list` is empty (previously UnboundLocalError).
        True_Pos_gdf = None
        False_Pos_gdf = None
        False_Neg_gdf = None

        if calculate_class_scores:
            if not ground_truth_class_field:
                raise ValueError('Must provide ground_truth_class_field if '
                                 'using calculate_class_scores.')
            if class_list == ['all']:
                # Build a fresh list; the (shared) default is never mutated.
                class_list = list(
                    self.ground_truth_GDF[ground_truth_class_field].unique())
                if not self.proposal_GDF.empty:
                    class_list.extend(
                        list(self.proposal_GDF['__max_conf_class'].unique()))
                class_list = list(set(class_list))

        for class_id in class_list:
            iou_field = "{}_{}".format(iou_field_prefix, class_id)
            if class_id != 'all':
                self.ground_truth_GDF_Edit = self.ground_truth_GDF[
                    self.ground_truth_GDF[
                        ground_truth_class_field] == class_id].copy(deep=True)
            else:
                self.ground_truth_GDF_Edit = self.ground_truth_GDF.copy(
                    deep=True)

            for _, pred_row in tqdm(self.proposal_GDF.iterrows()):
                # Test 'all' first so the class column is only read when a
                # specific class is being scored.
                if not (class_id == 'all'
                        or pred_row['__max_conf_class'] == class_id):
                    continue
                iou_GDF = iou.calculate_iou(pred_row.geometry,
                                            self.ground_truth_GDF_Edit)
                best_iou = 0
                if not iou_GDF.empty:
                    max_iou_row = iou_GDF.loc[iou_GDF['iou_score'].idxmax(
                        axis=0, skipna=True)]
                    if max_iou_row['iou_score'] > miniou:
                        best_iou = max_iou_row['iou_score']
                        # Each ground truth polygon may be matched only once.
                        self.ground_truth_GDF_Edit = \
                            self.ground_truth_GDF_Edit.drop(
                                max_iou_row.name, axis=0)
                self.proposal_GDF.loc[pred_row.name, iou_field] = best_iou

            if self.proposal_GDF.empty:
                TruePos = 0
                FalsePos = 0
                True_Pos_gdf = None
                False_Pos_gdf = None
            else:
                try:
                    # Matched proposals hold an IoU strictly above `miniou`;
                    # unmatched ones hold 0. Splitting on `>` / `<=` keeps this
                    # consistent with the match test above (and stays correct
                    # for miniou == 0). Rows of other classes hold NaN and fall
                    # in neither group.
                    True_Pos_gdf = self.proposal_GDF[
                        self.proposal_GDF[iou_field] > miniou]
                    TruePos = True_Pos_gdf.shape[0]
                    if TruePos == 0:
                        True_Pos_gdf = None
                    False_Pos_gdf = self.proposal_GDF[
                        self.proposal_GDF[iou_field] <= miniou]
                    FalsePos = False_Pos_gdf.shape[0]
                    if FalsePos == 0:
                        False_Pos_gdf = None
                except KeyError:  # no proposal of this class was scored
                    print("iou field {} missing".format(iou_field))
                    TruePos = 0
                    FalsePos = 0
                    False_Pos_gdf = None
                    True_Pos_gdf = None

            # number of remaining rows in ground_truth_gdf_edit after removing
            # matches is number of false negatives
            False_Neg_gdf = self.ground_truth_GDF_Edit
            FalseNeg = False_Neg_gdf.shape[0]
            if FalseNeg == 0:
                False_Neg_gdf = None
            if TruePos + FalsePos > 0:
                Precision = TruePos / float(TruePos + FalsePos)
            else:
                Precision = 0
            if TruePos + FalseNeg > 0:
                Recall = TruePos / float(TruePos + FalseNeg)
            else:
                Recall = 0
            if Recall * Precision > 0:
                F1Score = 2 * Precision * Recall / (Precision + Recall)
            else:
                F1Score = 0

            score_calc = {'class_id': class_id,
                          'iou_field': iou_field,
                          'TruePos': TruePos,
                          'FalsePos': FalsePos,
                          'FalseNeg': FalseNeg,
                          'Precision': Precision,
                          'Recall': Recall,
                          'F1Score': F1Score
                          }
            scoring_dict_list.append(score_calc)

        return scoring_dict_list, True_Pos_gdf, False_Neg_gdf, False_Pos_gdf