in libs/solaris/eval/base.py [0:0]
def eval_iou_return_GDFs(self, miniou=0.5, iou_field_prefix='iou_score',
                         ground_truth_class_field='',
                         calculate_class_scores=True,
                         class_list=['all']):
    """Evaluate IoU between the ground truth and proposals.

    Each proposal of a class is matched against the ground truth polygons
    of that class that are still unmatched; the ground truth polygon with
    the highest IoU is consumed when its score exceeds ``miniou``.

    Side effects: a ``"<iou_field_prefix>_<class_id>"`` column is written
    into ``self.proposal_GDF`` for every class scored, and
    ``self.ground_truth_GDF_Edit`` is overwritten for every class scored.

    Arguments
    ---------
    miniou : float, optional
        Minimum intersection over union score to qualify as a successful
        object detection event. Defaults to ``0.5``.
    iou_field_prefix : str, optional
        Prefix of the IoU score column(s) written to
        ``self.proposal_GDF``. Defaults to ``"iou_score"``.
    ground_truth_class_field : str, optional
        The column in ``self.ground_truth_GDF`` that indicates the class of
        each polygon. Required if using ``calculate_class_scores``.
    calculate_class_scores : bool, optional
        Should class-by-class scores be calculated? Defaults to ``True``.
    class_list : list, optional
        List of classes to be scored. Defaults to ``['all']``. When
        ``calculate_class_scores`` is set, ``['all']`` is expanded to every
        class found in the ground truth and (if any) the proposals;
        otherwise ``'all'`` scores every polygon as a single class.

    Returns
    -------
    scoring_dict_list : list
        list of score output dicts, one per scored class. The dicts
        contain the following keys: ::

            ('class_id', 'iou_field', 'TruePos', 'FalsePos', 'FalseNeg',
             'Precision', 'Recall', 'F1Score')

    True_Pos_gdf : gdf or None
        A geodataframe containing only true positive predictions
    False_Neg_gdf : gdf or None
        A geodataframe containing only false negative predictions
    False_Pos_gdf : gdf or None
        A geodataframe containing only false positive predictions

    The three geodataframes describe the LAST class that was scored, and
    each one is ``None`` when it would be empty (or when no class was
    scored at all).

    Raises
    ------
    ValueError
        If ``calculate_class_scores`` is set without a
        ``ground_truth_class_field``.
    """
    scoring_dict_list = []
    # Bind the returned frames up front so the final ``return`` is valid
    # even when ``class_list`` turns out to be empty.
    True_Pos_gdf = None
    False_Neg_gdf = None
    False_Pos_gdf = None

    if calculate_class_scores:
        if not ground_truth_class_field:
            raise ValueError('Must provide ground_truth_class_field if using calculate_class_scores.')
        if class_list == ['all']:
            # Rebinds (never mutates) the default list, so the shared
            # default argument stays intact between calls.
            class_list = list(
                self.ground_truth_GDF[ground_truth_class_field].unique())
            if not self.proposal_GDF.empty:
                class_list.extend(
                    list(self.proposal_GDF['__max_conf_class'].unique()))
            class_list = list(set(class_list))

    for class_id in class_list:
        iou_field = "{}_{}".format(iou_field_prefix, class_id)
        if class_id != 'all':  # this is probably unnecessary now
            class_mask = (
                self.ground_truth_GDF[ground_truth_class_field] == class_id)
            self.ground_truth_GDF_Edit = self.ground_truth_GDF[
                class_mask].copy(deep=True)
        else:
            self.ground_truth_GDF_Edit = self.ground_truth_GDF.copy(
                deep=True)

        for _, pred_row in tqdm(self.proposal_GDF.iterrows()):
            # 'all' is tested first so the class column is only looked up
            # when a specific class is being scored.
            if not (class_id == 'all'
                    or pred_row['__max_conf_class'] == class_id):
                continue
            iou_GDF = iou.calculate_iou(pred_row.geometry,
                                        self.ground_truth_GDF_Edit)
            matched_score = 0
            if not iou_GDF.empty:
                max_iou_row = iou_GDF.loc[iou_GDF['iou_score'].idxmax(
                    axis=0, skipna=True)]
                if max_iou_row['iou_score'] > miniou:
                    matched_score = max_iou_row['iou_score']
                    # A ground truth polygon can be matched only once.
                    self.ground_truth_GDF_Edit = \
                        self.ground_truth_GDF_Edit.drop(
                            max_iou_row.name, axis=0)
            self.proposal_GDF.loc[pred_row.name, iou_field] = matched_score

        if self.proposal_GDF.empty:
            TruePos = 0
            FalsePos = 0
            # No proposals means no true/false positive frames.
            True_Pos_gdf = None
            False_Pos_gdf = None
        else:
            try:
                # Proposals of other classes hold NaN in this column and
                # therefore fall into neither of the two selections.
                True_Pos_gdf = self.proposal_GDF[
                    self.proposal_GDF[iou_field] >= miniou]
                TruePos = True_Pos_gdf.shape[0]
                if TruePos == 0:
                    True_Pos_gdf = None
                False_Pos_gdf = self.proposal_GDF[
                    self.proposal_GDF[iou_field] < miniou]
                FalsePos = False_Pos_gdf.shape[0]
                if FalsePos == 0:
                    False_Pos_gdf = None
            except KeyError:  # handle missing iou_field
                # The column only exists once a proposal of this class
                # has been scored.
                print("iou field {} missing".format(iou_field))
                TruePos = 0
                FalsePos = 0
                False_Pos_gdf = None
                True_Pos_gdf = None

        # number of remaining rows in ground_truth_gdf_edit after removing
        # matches is number of false negatives
        False_Neg_gdf = self.ground_truth_GDF_Edit
        FalseNeg = False_Neg_gdf.shape[0]
        if FalseNeg == 0:
            False_Neg_gdf = None

        if TruePos + FalsePos > 0:
            Precision = TruePos / float(TruePos + FalsePos)
        else:
            Precision = 0
        if TruePos + FalseNeg > 0:
            Recall = TruePos / float(TruePos + FalseNeg)
        else:
            Recall = 0
        if Recall * Precision > 0:
            F1Score = 2 * Precision * Recall / (Precision + Recall)
        else:
            F1Score = 0

        scoring_dict_list.append({
            'class_id': class_id,
            'iou_field': iou_field,
            'TruePos': TruePos,
            'FalsePos': FalsePos,
            'FalseNeg': FalseNeg,
            'Precision': Precision,
            'Recall': Recall,
            'F1Score': F1Score,
        })

    return scoring_dict_list, True_Pos_gdf, False_Neg_gdf, False_Pos_gdf