in assets/training/model_evaluation/src/compute_metrics.py [0:0]
def filter_ground_truths(data, task_type, column_name=None, extra_y_test_cols=None, config=None):
    """Keep only the label columns of the ground truths and flag multi-valued ones.

    Args:
        data: Ground-truth table. It is accessed through ``columns``, column
            lookup, ``iloc`` and ``empty`` (presumably always a
            ``pd.DataFrame`` -- confirm against callers).
        task_type: Task identifier, compared against the ``TASK`` constants.
        column_name (optional): Name of the primary label column.
            Defaults to None.
        extra_y_test_cols (optional): Iterable of additional label columns to
            keep alongside ``column_name``. Defaults to None.
        config (optional): Mapping-like config. It is only read for
            ``TASK.TEXT_GENERATION`` to look up the sub-task; ``None`` is
            treated as "no sub-task configured". Defaults to None.

    Returns:
        tuple: ``(data, multiple_ground_truth)``.

        * For image instance segmentation, image object detection,
          forecasting and text generation with the code-generation sub-task,
          ``data`` is returned untouched and the flag is ``True``.
        * For ``TASK.QnA`` with a ``column_name``, the flag is ``True`` when
          the first row of the first column is a dict with more than one key,
          or when the first row of ``column_name`` is a ``list`` or
          ``np.ndarray``.
        * When ``data`` has more than one column and any label column was
          requested, ``data`` is narrowed to ``column_name`` followed by
          ``extra_y_test_cols``. If no label column was requested, all
          columns are kept and a warning is logged.

    Raises:
        The exception built by ``get_azureml_exception(DataValidationException,
        InvalidGroundTruthColumnNameData, None)`` when ``data`` has more than
        one column and a requested label column is not present in it.
    """
    # These tasks need several ground-truth columns, so nothing is filtered.
    keep_all_columns = task_type in [
        TASK.IMAGE_INSTANCE_SEGMENTATION,
        TASK.IMAGE_OBJECT_DETECTION,
        TASK.FORECASTING,
    ]
    if not keep_all_columns and task_type == TASK.TEXT_GENERATION:
        # ``config`` defaults to None; calling ``.get`` on it would raise
        # AttributeError, so fall back to the "no sub-task" value instead.
        subtask = ""
        if config is not None:
            subtask = config.get(constants.TextGenerationColumns.SUBTASKKEY, "")
        keep_all_columns = subtask == constants.SubTask.CODEGENERATION
    if keep_all_columns:
        return data, True

    multiple_ground_truth = False

    # For Question-Answering, peek at the first row to detect multiple ground
    # truths. An empty table has no first row, so there is nothing to inspect.
    if task_type == TASK.QnA and column_name and not data.empty:
        # Positional access: the first row is wanted regardless of how the
        # index is labelled (a label lookup of ``0`` fails on a re-indexed or
        # filtered frame).
        first_value = data[data.columns[0]].iloc[0]
        if isinstance(data, pd.DataFrame) and isinstance(first_value, dict) and len(first_value) > 1:
            multiple_ground_truth = True
        if column_name in data.columns and isinstance(data[column_name].iloc[0], (list, np.ndarray)):
            multiple_ground_truth = True

    if len(data.columns) > 1:
        label_cols = []
        if column_name:
            label_cols.append(column_name)
        if extra_y_test_cols:
            label_cols.extend(extra_y_test_cols)

        if not label_cols:
            # No label column was requested: keep every column as-is.
            logger.warning("Multiple columns found in ground truths. Taking all into consideration.")
        elif all(col in data.columns for col in label_cols):
            data = data[label_cols]
        else:
            exception = get_azureml_exception(DataValidationException, InvalidGroundTruthColumnNameData, None)
            log_traceback(exception, logger)
            raise exception

    return data, multiple_ground_truth