# supporting-blog-content/langraph-retrieval-agent-template-demo/src/retrieval_graph/graph.py
"""Main entrypoint for the conversational retrieval graph.
This module defines the core structure and functionality of the conversational
retrieval graph. It includes the main graph definition, state management,
and key functions for processing user inputs, generating queries, retrieving
relevant documents, and formulating responses.
"""

import logging
from datetime import datetime, timezone
from typing import cast

from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import RunnableConfig
from langgraph.graph import StateGraph

from retrieval_graph import retrieval
from retrieval_graph.configuration import Configuration
from retrieval_graph.state import InputState, State
from retrieval_graph.utils import format_docs, get_message_text, load_chat_model

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SearchQuery(BaseModel):
    """Search the indexed documents for a query."""

    query: str


# Define the function that calls the model
async def generate_query(
    state: State, *, config: RunnableConfig
) -> dict[str, list[str]]:
    """Generate a search query based on the current state and configuration.

    This function analyzes the messages in the state and generates an appropriate
    search query. For the first message, it uses the user's input directly.
    For subsequent messages, it uses a language model to generate a refined query.

    Args:
        state (State): The current state containing messages and other information.
        config (RunnableConfig): Configuration for the query generation process.

    Returns:
        dict[str, list[str]]: A dictionary with a 'queries' key containing a list
            of generated queries.

    Behavior:
        - If there is only one message (the first user input), it uses that as the query.
        - For subsequent messages, it uses a language model to generate a refined query.
        - The function uses the configuration to set up the prompt and model for
          query generation.
    """
    messages = state.messages
    if len(messages) == 1:
        # It's the first user question. We will use the input directly to search.
        human_input = get_message_text(messages[-1])
        return {"queries": [human_input]}
    else:
        configuration = Configuration.from_runnable_config(config)
        # Feel free to customize the prompt, model, and other logic!
        prompt = ChatPromptTemplate.from_messages(
            [
                ("system", configuration.query_system_prompt),
                ("placeholder", "{messages}"),
            ]
        )
        model = load_chat_model(configuration.query_model).with_structured_output(
            SearchQuery
        )

        message_value = await prompt.ainvoke(
            {
                "messages": state.messages,
                "queries": "\n- ".join(state.queries),
                "system_time": datetime.now(tz=timezone.utc).isoformat(),
            },
            config,
        )
        generated = cast(SearchQuery, await model.ainvoke(message_value, config))
        return {
            "queries": [generated.query],
        }


async def retrieve(
    state: State, *, config: RunnableConfig
) -> dict[str, list[Document]]:
    """Retrieve documents based on the latest query in the state.

    This function takes the current state and configuration, uses the latest query
    from the state to retrieve relevant documents using the retriever, and returns
    the retrieved documents.

    Args:
        state (State): The current state containing the generated queries.
        config (RunnableConfig): Configuration for the retrieval process.

    Returns:
        dict[str, list[Document]]: A dictionary with a single key "retrieved_docs"
            containing a list of retrieved Document objects.
    """
    with retrieval.make_retriever(config) as retriever:
        query = state.queries[-1]
        response = await retriever.ainvoke(query, config)
        return {"retrieved_docs": response}


async def respond(
    state: State, *, config: RunnableConfig
) -> dict[str, list[BaseMessage]]:
    """Call the LLM powering our "agent"."""
    configuration = Configuration.from_runnable_config(config)
    # Feel free to customize the prompt, model, and other logic!
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", configuration.response_system_prompt),
            ("placeholder", "{messages}"),
        ]
    )
    model = load_chat_model(configuration.response_model)

    retrieved_docs = format_docs(state.retrieved_docs)
    message_value = await prompt.ainvoke(
        {
            "messages": state.messages,
            "retrieved_docs": retrieved_docs,
            "system_time": datetime.now(tz=timezone.utc).isoformat(),
        },
        config,
    )
    response = await model.ainvoke(message_value, config)
    # We return a list, because this will get added to the existing list
    return {"response": [response]}


async def predict_query(
    state: State, *, config: RunnableConfig
) -> dict[str, list[BaseMessage]]:
    """Predict a likely follow-up question based on the conversation so far.

    Uses the configured response model together with the conversation history,
    previous queries, and retrieved documents to suggest the user's next question.
    """
    configuration = Configuration.from_runnable_config(config)
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", configuration.predict_next_question_prompt),
            ("placeholder", "{messages}"),
        ]
    )
    model = load_chat_model(configuration.response_model)

    user_query = state.queries[-1] if state.queries else "No prior query available"
    previous_queries = "\n- ".join(state.queries) if state.queries else "None"
    retrieved_docs = format_docs(state.retrieved_docs)
    message_value = await prompt.ainvoke(
        {
            "messages": state.messages,
            "retrieved_docs": retrieved_docs,
            "previous_queries": previous_queries,
            "user_query": user_query,  # Use the most recent query as primary input
            "system_time": datetime.now(tz=timezone.utc).isoformat(),
        },
        config,
    )
    next_question = await model.ainvoke(message_value, config)
    return {"next_question": [next_question]}


# Define a new graph (It's just a pipe)
builder = StateGraph(State, input=InputState, config_schema=Configuration)

builder.add_node(generate_query)
builder.add_node(retrieve)
builder.add_node(respond)
builder.add_node(predict_query)
builder.add_edge("__start__", "generate_query")
builder.add_edge("generate_query", "retrieve")
builder.add_edge("retrieve", "respond")
builder.add_edge("respond", "predict_query")

# Finally, we compile it!
# This compiles it into a graph you can invoke and deploy.
graph = builder.compile(
    interrupt_before=[],  # Add node names here if you want to update the state before they're called
    interrupt_after=[],
)
graph.name = "RetrievalGraph"