def filter_predictions()

in assets/training/model_evaluation/src/compute_metrics.py [0:0]


def filter_predictions(data, task_type, column_name):
    """Reduce a predictions DataFrame to the prediction column for a task.

    Behaviour by task:

    * ``TASK.CHAT_COMPLETION``: ``data`` is returned untouched, because chat
      completion predictions span multiple columns.
    * ``TASK.QnA``: only one prediction per row is supported. If the cells of
      the first column are dicts with more than one key, every cell of that
      column is replaced by the first item of its ``column_name`` entry
      (``""`` when that entry is empty). If the ``column_name`` column holds
      lists or numpy arrays, every such cell is replaced by its first item
      (``""`` when empty); cells of any other type are left as they are.
    * Every task except chat completion: when ``data`` has more than one
      column and ``column_name`` is one of them, only that column is kept.
      Otherwise all columns are kept and a warning is logged.

    Note:
        For ``TASK.QnA`` the affected columns are overwritten on ``data``
        itself, so the caller's DataFrame is modified.

    Args:
        data (pd.DataFrame): Predictions to filter.
        task_type: Task identifier, compared against members of ``TASK``.
        column_name (str): Name of the prediction column. May be falsy; for
            ``TASK.QnA`` a falsy value falls back to
            ``constants.PREDICTIONS_COLUMN_NAME``.

    Returns:
        pd.DataFrame: ``data`` itself, or a new single-column DataFrame
        holding only ``column_name`` when the column filter was applied.

    Raises:
        Exception: The exception built by ``get_azureml_exception`` from
            ``DataValidationException`` / ``InvalidPredictionColumnNameData``
            when the ``column_name`` entry cannot be read from dict cells.
    """
    if task_type == TASK.CHAT_COMPLETION:
        # do not filter as these contain multiple prediction columns
        return data

    if task_type == TASK.QnA:
        if not column_name:
            column_name = constants.PREDICTIONS_COLUMN_NAME

        # Built with implicit concatenation so the logged text carries no
        # source-indentation whitespace.
        multiple_predictions_warning = (
            "Multiple predictions are not supported for the Question and Answering currently. "
            "Considering only the first prediction in case of multiple values."
        )

        def _first_item(cell):
            # Collapse a list/array cell to its first item; anything else passes through.
            if isinstance(cell, (list, np.ndarray)):
                return cell[0] if len(cell) > 0 else ""
            return cell

        # The first cell is inspected to detect the data layout, which is only
        # possible when there is at least one column and one row.
        if len(data.columns) > 0 and len(data) > 0:
            first_column = data.columns[0]
            # Positional access: the index is not guaranteed to contain label 0.
            first_cell = data[first_column].iloc[0]
            # Dict cells with several keys: pull the prediction entry out of each dict.
            if isinstance(first_cell, dict) and len(first_cell) > 1:
                try:
                    if isinstance(data, pd.DataFrame):
                        logger.warning(multiple_predictions_warning)
                        data[first_column] = data[first_column].apply(
                            lambda row: row[column_name][0] if len(row[column_name]) > 0 else ""
                        )
                except Exception as e:
                    message_kwargs = {"prediction_column": column_name}
                    exception = get_azureml_exception(DataValidationException, InvalidPredictionColumnNameData, e,
                                                      **message_kwargs)
                    log_traceback(exception, logger)
                    raise exception from e

            if column_name in data.columns:
                first_prediction = data[column_name].iloc[0]
                if isinstance(first_prediction, (list, np.ndarray)):
                    logger.warning(multiple_predictions_warning)
                    data[column_name] = data[column_name].apply(_first_item)

    if len(data.columns) > 1:
        if column_name and column_name in data.columns:
            logger.info("Multiple columns found. Picking only prediction column.")
            try:
                data = data[[column_name]]
            except Exception as e:
                message_kwargs = {"prediction_column": column_name}
                exception = get_azureml_exception(DataValidationException, InvalidPredictionColumnNameData, e,
                                                  **message_kwargs)
                log_traceback(exception, logger)
                # Best effort: the failure is logged but not raised, and the
                # unfiltered data is returned instead.
                logger.info("Error in filtering prediction columns. Taking all into consideration.")
        else:
            logger.warning("Multiple columns found in predictions. Taking all into consideration.")
    return data