skills/retrieval_augmented_generation/evaluation/eval_retrieval.py (108 lines of code) (raw):
from typing import Dict, Union, Any, List
import ast
def calculate_mrr(retrieved_links: List[str], correct_links) -> float:
for i, link in enumerate(retrieved_links, 1):
if link in correct_links:
return 1 / i
return 0
def evaluate_retrieval(retrieved_links, correct_links):
correct_links = ast.literal_eval(correct_links)
true_positives = len(set(retrieved_links) & set(correct_links))
precision = true_positives / len(retrieved_links) if retrieved_links else 0
recall = true_positives / len(correct_links) if correct_links else 0
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
mrr= calculate_mrr(retrieved_links, correct_links)
return precision, recall, mrr, f1
def get_assert(output: str, context) -> Union[bool, float, Dict[str, Any]]:
correct_chunks = context['vars']['correct_chunks']
try:
precision, recall, mrr, f1 = evaluate_retrieval(output, correct_chunks)
metrics: Dict[str, float] = {}
metrics['precision'] = precision
metrics['recall'] = recall
metrics['f1'] = f1
metrics['mrr'] = mrr
print("METRICS")
print(metrics)
overall_score = True
if f1 < 0.3:
overall_score = False
return {
"pass": overall_score, #if f1 > 0.3 we will pass, otherwise fail
"score": f1,
"reason": f"Precision: {precision} \n Recall: {recall} \n F1 Score: {f1} \n MRR: {mrr}",
"componentResults": [
{
"pass": True,
"score": mrr,
"reason": f"MRR is {mrr}",
"named_scores": {
"MRR": mrr
}
},
{
"pass": True,
"score": precision,
"reason": f"Precision is {precision}",
"named_scores": {
"Precision": precision
}
},
{
"pass": True,
"score": recall,
"reason": f"Recall is {recall}",
"named_scores": {
"Recall": recall
}
},
{
"pass": True,
"score": f1,
"reason": f"F1 is {f1}",
"named_scores": {
"F1": f1
}
},
],
}
except Exception as e:
return {
"pass": False, #if f1 > 0.3 we will pass, otherwise fail
"score": f1,
"reason": f"Unexpected error: {str(e)}",
"componentResults": [
{
"pass": False,
"score": mrr,
"reason": f"Unexpected error: {str(e)}",
"named_scores": {
"MRR": mrr
}
},
{
"pass": False,
"score": precision,
"reason": f"Unexpected error: {str(e)}",
"named_scores": {
"Precision": precision
}
},
{
"pass": False,
"score": recall,
"reason": f"Unexpected error: {str(e)}",
"named_scores": {
"Recall": recall
}
},
{
"pass": False,
"score": f1,
"reason": f"Unexpected error: {str(e)}",
"named_scores": {
"F1": f1
}
},
],
}