# filter_ground_truths — extracted from assets/training/model_evaluation/src/compute_metrics.py

def filter_ground_truths(data, task_type, column_name=None, extra_y_test_cols=None, config=None):
    """Filter the ground-truth DataFrame down to its label column(s) for the given task.

    Args:
        data (pd.DataFrame): Ground-truth data loaded for evaluation.
        task_type (str): One of the ``TASK`` constants identifying the evaluation task.
        column_name (str, optional): Name of the primary label column. Defaults to None.
        extra_y_test_cols (list, optional): Additional label columns to retain alongside
            ``column_name``. Defaults to None.
        config (dict, optional): Task configuration; only consulted for text-generation
            tasks to detect the code-generation sub-task. Defaults to None.

    Returns:
        tuple: ``(data, multiple_ground_truth)`` where ``data`` is the (possibly
        column-filtered) DataFrame and ``multiple_ground_truth`` is True when the task
        keeps several ground-truth values/columns per example.

    Raises:
        DataValidationException: If any requested label column is missing from ``data``.
    """
    multiple_ground_truth = False

    # Code-generation is a sub-task of text generation, signalled via config.
    # Guard against config being None (its default) before calling .get().
    is_code_generation = (
        task_type == TASK.TEXT_GENERATION
        and config is not None
        and config.get(constants.TextGenerationColumns.SUBTASKKEY, "") == constants.SubTask.CODEGENERATION
    )
    if task_type in [TASK.IMAGE_INSTANCE_SEGMENTATION, TASK.IMAGE_OBJECT_DETECTION,
                     TASK.FORECASTING] or is_code_generation:
        # Do not filter: these tasks require multiple columns in the ground truth.
        multiple_ground_truth = True
        return data, multiple_ground_truth

    # For Question-Answering, check whether each example carries multiple ground truths.
    if task_type == TASK.QnA and column_name:
        first_cell = data[data.columns[0]][0]
        if isinstance(first_cell, dict) and len(first_cell.keys()) > 1:
            # A dict cell with several keys means multiple ground-truth fields per row.
            # (The former try/except here was dead code: isinstance cannot raise.)
            if isinstance(data, pd.DataFrame):
                multiple_ground_truth = True
        if column_name in data.columns:
            # A list/ndarray cell means several acceptable answers for one example.
            if isinstance(data[column_name].iloc[0], (list, np.ndarray)):
                multiple_ground_truth = True

    if len(data.columns) > 1:
        label_cols = []
        if column_name:
            label_cols += [column_name]
        if extra_y_test_cols:
            label_cols += extra_y_test_cols
        if len(label_cols) > 0:
            if all(col in data.columns for col in label_cols):
                # Keep only the requested label column(s).
                data = data[label_cols]
            else:
                exception = get_azureml_exception(DataValidationException, InvalidGroundTruthColumnNameData, None)
                log_traceback(exception, logger)
                raise exception
        else:
            # No label columns were specified; pass everything through unchanged.
            logger.warning("Multiple columns found in ground truths. Taking all into consideration.")
    return data, multiple_ground_truth