skills/retrieval_augmented_generation/evaluation/provider_retrieval.py (90 lines of code) (raw):
import json
import os
from typing import Callable, List, Dict, Any, Tuple, Set
from vectordb import VectorDB, SummaryIndexedVectorDB
from anthropic import Anthropic
# Initialize the base VectorDB, which indexes the raw documentation chunks.
db = VectorDB("anthropic_docs")
# Load the Anthropic documentation. The relative path is resolved against the
# current working directory, not this file's location. Decode explicitly as
# UTF-8 so non-ASCII characters in the docs do not depend on the platform's
# default text encoding.
with open('../data/anthropic_docs.json', 'r', encoding='utf-8') as f:
    anthropic_docs = json.load(f)
db.load_data(anthropic_docs)
def retrieve_base(query, options, context):
    """Return the chunk links of the 3 closest chunks in the base vector DB.

    The ``(query, options, context)`` signature and ``{"output": ...}`` return
    look like an evaluation-harness provider hook (presumably promptfoo, so
    confirm). ``query`` and ``options`` are accepted but unused. The search text
    comes from ``context['vars']['query']``.

    Returns:
        ``{"output": [chunk_link, ...]}`` in search-result order.
    """
    search_text = context['vars']['query']
    links = [hit['metadata']['chunk_link'] for hit in db.search(search_text, k=3)]
    print(links)
    return {"output": links}
# Initialize the summary-indexed VectorDB used by retrieve_level_two.
db_summary = SummaryIndexedVectorDB("anthropic_docs_summaries")
# Load the summary-indexed Anthropic documentation. The relative path is
# resolved against the current working directory. Decode explicitly as UTF-8
# so loading does not depend on the platform's default text encoding.
with open("../data/anthropic_summary_indexed_docs.json", 'r', encoding='utf-8') as f:
    anthropic_docs_summaries = json.load(f)
db_summary.load_data(anthropic_docs_summaries)
def retrieve_level_two(query, options, context):
    """Return the chunk links of the 3 closest entries in the summary-indexed DB.

    Only ``context['vars']['query']`` is used as the search text. The ``query``
    and ``options`` parameters are part of the call signature but ignored.

    Returns:
        A dict whose ``output`` key holds the chunk links in search-result order.
    """
    search_hits = db_summary.search(context['vars']['query'], k=3)
    chunk_links = [hit['metadata']['chunk_link'] for hit in search_hits]
    print(chunk_links)
    return dict(output=chunk_links)
def _rerank_results(query: str, results: List[Dict], k: int = 3) -> List[Dict]:
# Prepare the summaries with their indices
summaries = []
print(len(results))
for i, result in enumerate(results):
summary = "[{}] Document: {}".format(
i,
result['metadata']['chunk_heading'],
result['metadata']['summary']
)
summary += " \n {}".format(result['metadata']['text'])
summaries.append(summary)
# Join summaries with newlines
joined_summaries = "\n".join(summaries)
prompt = f"""
Query: {query}
You are about to be given a group of documents, each preceded by its index number in square brackets. Your task is to select the only {k} most relevant documents from the list to help us answer the query.
{joined_summaries}
Output only the indices of {k} most relevant documents in order of relevance, separated by commas, enclosed in XML tags here:
<relevant_indices>put the numbers of your indices here, seeparted by commas</relevant_indices>
"""
client = Anthropic(api_key=os.environ.get('ANTHROPIC_API_KEY'))
try:
response = client.messages.create(
model="claude-3-5-sonnet-20241022",
max_tokens=50,
messages=[{"role": "user", "content": prompt}, {"role": "assistant", "content": "<relevant_indices>"}],
temperature=0,
stop_sequences=["</relevant_indices>"]
)
# Extract the indices from the response
response_text = response.content[0].text.strip()
indices_str = response_text
relevant_indices = []
for idx in indices_str.split(','):
try:
relevant_indices.append(int(idx.strip()))
except ValueError:
continue # Skip invalid indices
print(indices_str)
print(relevant_indices)
# If we didn't get enough valid indices, fall back to the top k by original order
if len(relevant_indices) == 0:
relevant_indices = list(range(min(k, len(results))))
# Ensure we don't have out-of-range indices
relevant_indices = [idx for idx in relevant_indices if idx < len(results)]
# Return the reranked results
reranked_results = [results[idx] for idx in relevant_indices[:k]]
# Assign descending relevance scores
for i, result in enumerate(reranked_results):
result['relevance_score'] = 100 - i # Highest score is 100, decreasing by 1 for each rank
return reranked_results
except Exception as e:
print(f"An error occurred during reranking: {str(e)}")
# Fall back to returning the top k results without reranking
return results[:k]
# Initialize a separate summary-indexed VectorDB used by retrieve_level_three.
db_rerank = SummaryIndexedVectorDB("anthropic_docs_summaries_rerank")
# Load the summary-indexed Anthropic documentation. The relative path is
# resolved against the current working directory. Decode explicitly as UTF-8
# so loading does not depend on the platform's default text encoding.
with open("../data/anthropic_summary_indexed_docs.json", 'r', encoding='utf-8') as f:
    anthropic_docs_summaries = json.load(f)
db_rerank.load_data(anthropic_docs_summaries)
def retrieve_level_three(query, options, context):
    """Over-fetch from the summary DB, rerank with Claude, return the top-3 links.

    ``query`` and ``options`` are accepted for the call signature but unused.

    Returns:
        ``{"output": [chunk_link, ...]}`` in reranked order.
    """
    # Read the search text from context vars, as retrieve_base and
    # retrieve_level_two do, so every retrieval strategy sees the same input.
    input_query = context['vars']['query']
    # Step 1: fetch a wider candidate pool (20) for the reranker to choose from.
    initial_results = db_rerank.search(input_query, k=20)
    # Step 2: let Claude pick and order the 3 most relevant candidates.
    reranked_results = _rerank_results(input_query, initial_results, k=3)
    outputs = [result['metadata']['chunk_link'] for result in reranked_results]
    print(outputs)
    return {"output": outputs}