# 4-mmrag_tooluse/mmrag_bh.py
# %%
import os
import re
import base64
import io
import json
import logging
import argparse
from typing import Dict, List, Optional, Tuple
from PIL import Image
import fitz # PyMuPDF
from concurrent.futures import ThreadPoolExecutor
from schema_definitions import schema_dict
from database import get_database_info
from config import TRIAGE_SYSTEM_PROMPT
import sqlite3
from openai import OpenAI
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import VectorParams, Distance
# %%
# Requires a running Qdrant instance, e.g.: docker run -d -p 6333:6333 qdrant/qdrant
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# %%
# Configuration
class Config:
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
QDRANT_HOST = 'localhost'
QDRANT_PORT = 6333
SLIDES_FOLDER = "./earnings_reports_sample"
TABLE_JSON_FOLDER = "./table_json"
BASE64_OUTPUT_FOLDER = "./base64_images"
COLLECTION_NAME = "image_embeddings"
EMBEDDING_MODEL = "text-embedding-3-small"
GPT_MODEL = "gpt-4o-2024-08-06"
# Initialize clients
client = OpenAI(api_key=Config.OPENAI_API_KEY)
qdrant_client = QdrantClient(host=Config.QDRANT_HOST, port=Config.QDRANT_PORT)
def encode_image(image: Image) -> str:
"""
Encodes a PIL Image into a base64 string.
"""
buffered = io.BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
return img_str
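
# A minimal usage sketch (hypothetical image) showing the round trip:
#   img = Image.new("RGB", (4, 4), "white")
#   b64 = encode_image(img)
#   Image.open(io.BytesIO(base64.b64decode(b64)))  # decodes back to the PNG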
def pdf_to_base64_images(pdf_path: str) -> List[str]:
"""
Converts each page of a PDF into a base64-encoded PNG image.
"""
try:
pdf_document = fitz.open(pdf_path)
base64_images = []
for page_num in range(len(pdf_document)):
page = pdf_document.load_page(page_num)
pix = page.get_pixmap()
img = Image.open(io.BytesIO(pix.tobytes()))
base64_image = encode_image(img)
base64_images.append(base64_image)
logger.info(f"Processed {len(base64_images)} pages from {pdf_path}")
return base64_images
except Exception as e:
logger.error(f"Error processing PDF {pdf_path}: {e}")
return []
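
# Example (assuming a local PDF path): one base64 string per page.
#   pages = pdf_to_base64_images("./earnings_reports_sample/report.pdf")
#   print(len(pages))  # number of pages rendered to PNG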
def save_base64_image(base64_image: str, folder: str, filename: str) -> str:
"""
Saves a base64 encoded image to a file and returns the file path.
"""
os.makedirs(folder, exist_ok=True)
file_path = os.path.join(folder, filename)
with open(file_path, "w") as f:
f.write(base64_image)
return file_path
def process_folder(folder: str, base64_output_folder: str) -> List[Dict[str, str]]:
"""
Processes all PDFs in a folder and extracts base64 images along with their quarter information.
"""
images_data = []
quarter_pattern = r'Q[1-4]\d{2}'
for file in os.listdir(folder):
if file.endswith(".pdf"):
match = re.search(quarter_pattern, file)
if match:
quarter_info = match.group()
pdf_path = os.path.join(folder, file)
base64_images = pdf_to_base64_images(pdf_path)
for i, base64_image in enumerate(base64_images):
base64_filename = f"{os.path.splitext(file)[0]}_page_{i}.txt"
base64_path = save_base64_image(
base64_image, base64_output_folder, base64_filename)
images_data.append({
'quarter_info': quarter_info,
'base64_image_path': base64_path,
'original_pdf_path': pdf_path
})
else:
logger.warning(
f"No quarter information found in filename: {file}")
return images_data
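
# Filenames are expected to embed the quarter as Q<1-4><YY>, e.g. a hypothetical
# "AMZN-Q124-Earnings.pdf" yields quarter_info "Q124"; files without a match are skipped.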
def analyze_image(base64_image: str, quarter_info: str) -> Dict:
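    """
    Classifies a page image as containing a table or graphs using GPT-4o
    structured output, returning a dict with "image_category",
    "content_output", and "quarter_info".
    """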
system_prompt = f"""
    Analyze the provided image and determine if it contains graphs or tabular data.
- If the image contains a table:
        - Shorten the table title to one of ["Free_Cash_Flow_Reconciliation", "Free_Cash_Flow_Less_Principal_Repayments", "Free_Cash_Flow_Less_Equipment_Finance_Leases"].
- Transcribe the table's title under the "content_output" key.
- Set "image_category" to "table".
- If the image contains graphs:
- Set "image_category" to "graphs".
- Provide a detailed analysis/summary of the graphs, including:
- **Descriptions** of what each graph represents.
- **Key data points** presented as bullet points or numbered lists.
- **Insights or takeaways or trends** derived from the graphs.
The quarter information is: {quarter_info}. Please use that as the value for the JSON key "quarter_info".
"""
response = client.chat.completions.create(
model=Config.GPT_MODEL,
response_format={
"type": "json_schema",
"json_schema": {
"name": "image_analysis",
"schema": {
"type": "object",
"properties": {
"image_category": {"type": "string"},
"content_output": {"type": "string"},
"quarter_info": {"type": "string"}
},
"required": ["image_category", "content_output", "quarter_info"],
"additionalProperties": False
},
"strict": True
}
},
messages=[
{
"role": "system",
"content": system_prompt
},
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {
"url": f"data:image/png;base64,{base64_image}", "detail": "high"}}
]
}
],
temperature=0.0,
)
response_string = response.choices[0].message.content
response_dict = json.loads(response_string)
return response_dict
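
# The structured output is constrained by the json_schema above, e.g.:
#   {"image_category": "table",
#    "content_output": "Free_Cash_Flow_Reconciliation",
#    "quarter_info": "Q124"}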
def parse_table(base64_image: str, table_title: str, report_date: str) -> Dict:
"""
Parses a table from an image, formats it according to a predefined JSON schema,
and saves the resulting JSON to the TABLE_JSON_FOLDER.
"""
relevant_schema = schema_dict.get(table_title)
if relevant_schema is None:
logger.warning(f"No schema found for table title: {table_title}")
return {}
    # The schema itself is enforced via the response_format parameter below.
system_prompt = f"""
You are an AI assistant tasked with extracting and structuring data from images containing tables.
**Instructions:**
- Extract all data from the provided table image.
- Format the extracted data according to the JSON schema provided below.
- Ensure that all fields are correctly populated and adhere strictly to the schema specifications.
- Use the following values for additional fields:
- `"title"`: "{table_title}"
- `"report_date"`: "{report_date}"
**Output Format:**
- Provide the output strictly in JSON format without any additional text or explanations.
"""
messages = [
{
"role": "system",
"content": system_prompt
},
{
"role": "user",
"content": [
{"type": "text", "text": "Please extract and format the data from the following table image according to the provided JSON schema."},
{"type": "image_url", "image_url": {
"url": f"data:image/png;base64,{base64_image}", "detail": "high"}}
]
}
]
try:
response = client.chat.completions.create(
model=Config.GPT_MODEL,
response_format=relevant_schema,
messages=messages,
temperature=0.0
)
response_json = json.loads(response.choices[0].message.content)
os.makedirs(Config.TABLE_JSON_FOLDER, exist_ok=True)
filename = f"{table_title}_{report_date}.json".replace(" ", "_")
file_path = os.path.join(Config.TABLE_JSON_FOLDER, filename)
with open(file_path, 'w') as json_file:
json.dump(response_json, json_file, indent=4)
logger.info(f"Saved parsed table JSON to {file_path}")
return response_json
except Exception as e:
logger.error(f"Error in parse_table: {e}")
return {}
def process_single_image(image_data: Dict[str, str]) -> Dict:
"""
Processes a single image: analyzes it and, if it's a table, parses it.
"""
with open(image_data['base64_image_path'], 'r') as f:
base64_image = f.read()
quarter_info = image_data['quarter_info']
analysis = analyze_image(base64_image, quarter_info)
if not analysis:
return {}
analysis['base64_image_path'] = image_data['base64_image_path']
analysis['original_pdf_path'] = image_data['original_pdf_path']
if analysis.get('image_category') == 'table':
table_title = analysis['content_output']
parsed_data = parse_table(base64_image, table_title, quarter_info)
analysis['parsed_table_data'] = parsed_data
return analysis
def process_images_concurrently(images_data: List[Dict[str, str]]) -> List[Dict]:
"""
Processes all images concurrently using ThreadPoolExecutor.
"""
with ThreadPoolExecutor() as executor:
results = list(executor.map(process_single_image, images_data))
return results
def get_embedding(text: str, model: str = Config.EMBEDDING_MODEL) -> List[float]:
"""
Retrieves the embedding for the provided text using OpenAI's embedding model.
"""
text = text.replace("\n", " ")
return client.embeddings.create(input=[text], model=model).data[0].embedding
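
# Example: text-embedding-3-small returns 1536-dimensional vectors.
#   vec = get_embedding("free cash flow trend")
#   len(vec)  # 1536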
def create_qdrant_collection(collection_name: str, vector_size: int, distance_metric: Distance = Distance.COSINE):
    """
    Creates a Qdrant collection with the specified configuration.
    Note: recreate_collection drops any existing collection with the same name.
    """
    qdrant_client.recreate_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(
            size=vector_size,
            distance=distance_metric
        )
    )
logger.info(
f"Collection '{collection_name}' created with vector size {vector_size} and distance metric '{distance_metric}'.")
def insert_data_to_qdrant(
client: QdrantClient,
collection_name: str,
embeddings: List[List[float]],
payloads: List[Dict],
    ids: Optional[List[int]] = None
):
"""
Inserts embeddings and their associated payloads into the specified Qdrant collection.
"""
if ids is None:
ids = list(range(len(embeddings)))
points = [
models.PointStruct(
id=idx,
vector=embedding,
payload=payload
)
for idx, embedding, payload in zip(ids, embeddings, payloads)
]
client.upsert(
collection_name=collection_name,
points=points
)
logger.info(
f"Inserted {len(points)} records into collection '{collection_name}'.")
def query_qdrant(
    query: str,
    collection_name: str = Config.COLLECTION_NAME,
    top_k: int = 1,
    embedding_model: str = Config.EMBEDDING_MODEL
) -> List[Tuple[str, str, str, str]]:
    """
    Queries the Qdrant collection with the provided query string and returns the top_k results.
    """
    try:
        embedded_query = get_embedding(query, model=embedding_model)
        search_results = qdrant_client.search(
            collection_name=collection_name,
            query_vector=embedded_query,
            limit=top_k
        )
        output = []
        for result in search_results:
            payload = result.payload
            title = f"{payload['image_category']} - {payload['quarter_info']}"
            text = payload['content_output']
            base64_image_path = payload['base64_image_path']
            original_pdf_path = payload['original_pdf_path']
            output.append((title, text, base64_image_path, original_pdf_path))
        return output
    except Exception as e:
        logger.error(f"Qdrant query failed: {e}")
        return []
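
# Example: retrieve the single closest slide for a question.
#   hits = query_qdrant("How did free cash flow evolve?", top_k=1)
#   title, text, b64_path, pdf_path = hits[0]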
def ask_database(conn, query):
"""Function to query SQLite database with a provided SQL query."""
try:
results = str(conn.execute(query).fetchall())
except Exception as e:
results = f"query failed with error: {e}"
return results
conn = sqlite3.connect("./earnings.db")
print("Opened database successfully")
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
[
f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
for table in database_schema_dict
]
)
TOOLS = [
{
"type": "function",
"function": {
"name": "ask_database",
"description": "Use this function to retrieve structured financial data from the SQL database. Input should be a fully formed SQL query.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": f"""
SQL query extracting the necessary information to answer the user's question.
The SQL query should be written using the following database schema:
{database_schema_string}
Ensure the query is syntactically correct and returns the data needed to fully address the user's request.
The query should be returned in plain text, not in JSON.
""",
}
},
"required": ["query"],
"additionalProperties": False,
},
}
},
{
"type": "function",
"function": {
"name": "query_qdrant",
"description": "Use this function to handle queries that require semantic understanding or retrieval from unstructured data sources using vector embeddings. Suitable for complex questions, trend analyses, and contextual information not directly available in the SQL database.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "A detailed and clear version of the user's original query capturing the intent and context of the information being sought.",
},
"filter": {
"type": "string",
"description": "Optional. Specific keywords or phrases to narrow down the search results for more precise retrieval. Should consist of 2 or 3 relevant words. Leave empty if no specific filter is needed.",
},
"top_k": {
"type": "integer",
"description": "Optional. The number of top relevant results to retrieve. Use higher numbers (e.g., 50) for broader queries requiring extensive information. Defaults to 10 if not specified.",
"default": 1,
},
},
"required": ["query"],
"additionalProperties": False,
},
}
}
]
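
# The triage model chooses between the two tools: ask_database for structured
# financial figures already in SQLite, query_qdrant for semantic retrieval over
# the indexed slide analyses (see main_loop below).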
class RAGSystem:
def __init__(self):
self.collection_name = Config.COLLECTION_NAME
def process_folder(self):
images_data = process_folder(
Config.SLIDES_FOLDER, Config.BASE64_OUTPUT_FOLDER)
if not images_data:
logger.info("No images to process.")
return None
return images_data
def analyze_images(self, images_data):
image_categorizations = process_images_concurrently(images_data)
logger.info(f"Processed {len(image_categorizations)} images.")
return [item for item in image_categorizations if item]
def prepare_data_for_indexing(self, image_categorizations):
non_table_images = [item for item in image_categorizations if item.get(
'image_category') != 'table']
if not non_table_images:
logger.info("No non-table images to process.")
return None, None
texts = [item['content_output'] for item in non_table_images]
embeddings = [get_embedding(text) for text in texts]
payloads = [
{
"image_category": item['image_category'],
"content_output": item['content_output'],
"quarter_info": item['quarter_info'],
"base64_image_path": item['base64_image_path'],
"original_pdf_path": item['original_pdf_path']
}
for item in non_table_images
]
return embeddings, payloads
def create_and_populate_collection(self, embeddings, payloads):
vector_size = len(embeddings[0])
create_qdrant_collection(self.collection_name, vector_size)
insert_data_to_qdrant(
qdrant_client, self.collection_name, embeddings, payloads)
def query(self, query_text: str, top_k: int = 1) -> List[Tuple[str, str, str, str]]:
        return query_qdrant(query_text, top_k=top_k)
def generate_response(self, query: str, retrieved_results: List[Tuple[str, str, str, str]]) -> str:
system_prompt = """You are an AI assistant specializing in analyzing financial documents and graphs.
Use the provided information and images to answer the user's query accurately and concisely."""
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": [{"type": "text", "text": query}]}
]
for title, text, base64_image_path, _ in retrieved_results:
with open(base64_image_path, 'r') as f:
base64_image = f.read()
messages.append({
"role": "user",
"content": [
{"type": "text", "text": f"Title: {title}\nContent: {text}"},
{"type": "image_url", "image_url": {
"url": f"data:image/png;base64,{base64_image}"}}
]
})
response = client.chat.completions.create(
model=Config.GPT_MODEL,
messages=messages,
max_tokens=300,
temperature=0.5
)
return response.choices[0].message.content
def process_and_index_data():
"""
Processes PDFs, extracts images, analyzes them, generates embeddings, and indexes them into Qdrant.
"""
rag_system = RAGSystem()
images_data = rag_system.process_folder()
if images_data:
image_categorizations = rag_system.analyze_images(images_data)
embeddings, payloads = rag_system.prepare_data_for_indexing(
image_categorizations)
if embeddings and payloads:
rag_system.create_and_populate_collection(embeddings, payloads)
logger.info("Data processing and indexing completed successfully.")
else:
logger.warning("No embeddings or payloads generated for indexing.")
else:
logger.warning("No images data found for processing.")
def query_rag_system(user_query):
"""
Starts an interactive query loop for retrieving and responding to user queries.
"""
rag_system = RAGSystem()
results = rag_system.query(user_query)
if results:
for title, text, base64_image_path, original_pdf_path in results:
logger.info(f"Title: {title}")
logger.info(f"Content: {text}")
logger.info(f"Base64 Image Path: {base64_image_path}")
logger.info(f"Original PDF Path: {original_pdf_path}")
logger.info("---")
# Generate response with retrieved information and images
response = rag_system.generate_response(user_query, results)
return response
def main_loop():
"""Interactive loop for processing user queries."""
print("Welcome to the Financial Assistant. Type 'exit' to quit.\n")
process_and_index_data()
while True:
user_query = input("User: ")
if user_query.lower() in ["exit", "quit"]:
print("Exiting the assistant. Goodbye!")
break
messages = [
{"role": "system", "content": TRIAGE_SYSTEM_PROMPT},
{"role": "user", "content": user_query},
]
response = client.chat.completions.create(
model='gpt-4o',
messages=messages,
tools=TOOLS,
tool_choice="required")
        # Step 2: determine if the response from the model includes a tool call.
        response_message = response.choices[0].message
        tool_calls = response_message.tool_calls
        if tool_calls:
            # The model returns the name of the tool/function to call and its argument(s).
            tool_call_id = tool_calls[0].id
            tool_function_name = tool_calls[0].function.name
            tool_query_string = json.loads(
                tool_calls[0].function.arguments)['query']
            # Step 3: call the function and append the results to the messages list.
            if tool_function_name == 'ask_database':
                results = ask_database(conn, tool_query_string)
                messages.append(response_message)
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "name": tool_function_name,
                    "content": results
                })
                # Step 4: send the tool result back to the model so the user gets a final answer.
                final_response = client.chat.completions.create(
                    model='gpt-4o',
                    messages=messages
                )
                print(final_response.choices[0].message.content)
            elif tool_function_name == 'query_qdrant':
                results = query_rag_system(tool_query_string)
                print(results)
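
# Illustrative session (actual answers depend on the indexed reports):
#   User: What was free cash flow in Q1 2024?
#     -> triage model calls ask_database with a generated SQL query
#   User: Summarize the revenue trends shown in the charts.
#     -> triage model calls query_qdrant; retrieved slides ground the answer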
if __name__ == "__main__":
main_loop()