in assets/training/model_evaluation/src/compute_metrics.py [0:0]
def filter_predictions(data, task_type, column_name):
    """Reduce a predictions DataFrame to the column(s) used for metric computation.

    Args:
        data (pd.DataFrame): Model predictions.
        task_type: Task identifier, compared against ``TASK`` members.
        column_name (str, optional): Name of the predictions column. For
            Question-Answering a falsy value is replaced by
            ``constants.PREDICTIONS_COLUMN_NAME``.

    Returns:
        pd.DataFrame: ``data`` unchanged for chat completion. Otherwise, when
        there are several columns and ``column_name`` is one of them, a frame
        holding only that column; if it is not present, all columns are kept.
        For Question-Answering, multi-valued predictions are reduced to their
        first value (``""`` when the list is empty). The caller's DataFrame is
        not modified in place.

    Raises:
        The exception built by ``get_azureml_exception(DataValidationException,
        InvalidPredictionColumnNameData, ...)`` when Question-Answering dict
        predictions cannot be unpacked using ``column_name``.
    """
    if task_type in [TASK.CHAT_COMPLETION]:
        # Do not filter: chat-completion output contains multiple prediction columns.
        return data

    # For Question-Answering, predictions may hold several candidate answers;
    # only the first one is used for metrics.
    if task_type == TASK.QnA:
        if not column_name:
            column_name = constants.PREDICTIONS_COLUMN_NAME
        multi_prediction_warning = (
            "Multiple predictions are not supported for the Question and Answering currently. "
            "Considering only the first prediction in case of multiple values."
        )
        # The first-row inspections below need at least one column and one row.
        if len(data.columns) > 0 and len(data) > 0:
            # Work on a copy so the caller's DataFrame is not mutated.
            data = data.copy()
            first_column = data.columns[0]
            # Positional lookup: label 0 need not exist in the index.
            first_value = data[first_column].iloc[0]
            # A dict with several keys (e.g. predictions plus auxiliary outputs)
            # is unpacked to the first entry stored under ``column_name``.
            if isinstance(first_value, dict) and len(first_value.keys()) > 1:
                try:
                    if isinstance(data, pd.DataFrame):
                        logger.warning(multi_prediction_warning)
                        data[first_column] = data[first_column].apply(
                            lambda x: x[column_name][0] if len(x[column_name]) > 0 else ""
                        )
                except Exception as e:
                    message_kwargs = {"prediction_column": column_name}
                    exception = get_azureml_exception(DataValidationException, InvalidPredictionColumnNameData, e,
                                                      **message_kwargs)
                    log_traceback(exception, logger)
                    raise exception from e
            if column_name in data.columns:
                first_prediction = data[column_name].iloc[0]
                if isinstance(first_prediction, (list, np.ndarray)):
                    logger.warning(multi_prediction_warning)
                    # Keep the first candidate; an empty candidate list maps to "",
                    # matching the dict-unpacking branch above.
                    data[column_name] = data[column_name].apply(
                        lambda x: x[0] if len(x) > 0 else ""
                    )

    if len(data.columns) > 1:
        if column_name and column_name in data.columns:
            logger.info("Multiple columns found. Picking only prediction column.")
            try:
                data = data[[column_name]]
            except Exception as e:
                # Deliberately best-effort: the failure is logged and all
                # columns are kept instead of aborting the evaluation.
                message_kwargs = {"prediction_column": column_name}
                exception = get_azureml_exception(DataValidationException, InvalidPredictionColumnNameData, e,
                                                  **message_kwargs)
                log_traceback(exception, logger)
                logger.info("Error in filtering prediction columns. Taking all into consideration.")
        else:
            logger.warning("Multiple columns found in predictions. Taking all into consideration.")
    return data