in databao/executors/lighthouse/graph.py [0:0]
def compile(self, model_config: LLMConfig) -> CompiledStateGraph[Any]:
    """Build and compile the agent's LLM <-> tool-executor state graph.

    Graph shape:
        START -> llm_node
        llm_node -> tool_executor   if the last message is an AIMessage with tool calls
        llm_node -> END             otherwise
        tool_executor -> END        if ``ready_for_user`` was set (a valid submit_result)
        tool_executor -> llm_node   otherwise

    Args:
        model_config: Used to create the chat model and passed to ``self._chat``
            on every LLM turn.

    Returns:
        The compiled graph, operating on ``AgentState``.
    """
    tools = self.make_tools()
    # Index tools by name once; setdefault keeps the first tool for a duplicated
    # name, matching a first-match linear search.
    tools_by_name: dict[str, Any] = {}
    for t in tools:
        tools_by_name.setdefault(t.name, t)
    llm_model = model_config.new_chat_model()
    model_with_tools = self._model_bind_tools(llm_model, tools)

    def llm_node(state: AgentState) -> dict[str, Any]:
        """Run one LLM turn and append the last message of ``self._chat``'s result."""
        response = self._chat(state["messages"], model_config, model_with_tools)
        return {"messages": [response[-1]]}

    def validate_submit(state: AgentState, tool_calls: list[Any]) -> "str | None":
        """Return why a submit_result turn must be rejected, or None if it is acceptable."""
        if len(tool_calls) > 1:
            return "submit_result must be the only tool call."
        known_ids = state.get("query_ids")
        if not known_ids:
            return "No queries have been executed yet."
        # Arguments are model-generated: a missing query_id must produce a reply the
        # LLM can act on, not a KeyError that aborts the graph.
        query_id = tool_calls[0]["args"].get("query_id")
        if query_id is None or query_id not in known_ids:
            available_ids = ", ".join(known_ids)
            return f"Query ID {query_id} not found. Available query IDs: {available_ids}"
        artifact = known_ids[query_id].artifact
        if artifact is None or "df" not in artifact:
            return f"Query {query_id} does not have a valid result."
        return None

    def tool_executor_node(state: AgentState) -> dict[str, Any]:
        """Execute the tool calls of the last AI message and fold results into the state.

        Returns a partial state update with one ToolMessage per tool call. A valid
        submit_result call (which must be the only call) sets ``ready_for_user`` to
        True along with the submitted query's ``sql``/``df`` and the requested
        ``visualization_prompt``. Otherwise ``query_ids`` gains an entry for each
        ``run_sql_query`` call, and ``ready_for_user`` is False.
        """
        last_message = state["messages"][-1]
        # should_continue only routes here for an AIMessage that has tool calls.
        assert isinstance(last_message, AIMessage)
        tool_calls = last_message.tool_calls
        is_ready_for_user = any(tc["name"] == "submit_result" for tc in tool_calls)

        if is_ready_for_user:
            error = validate_submit(state, tool_calls)
            if error is not None:
                # Reply to every tool call with the rejection so the model can retry.
                return {
                    "messages": [ToolMessage(error, tool_call_id=tc["id"]) for tc in tool_calls],
                    "ready_for_user": False,
                }

        query_ids = dict(state.get("query_ids") or {})
        sql = state.get("sql")
        df = state.get("df")
        visualization_prompt = state.get("visualization_prompt", "")
        # Index of the AI message being answered; combined with the per-call index,
        # it yields query ids that are unique across the conversation.
        message_index = len(state["messages"]) - 1
        tool_messages: list[ToolMessage] = []

        for idx, tool_call in enumerate(tool_calls):
            name = tool_call["name"]
            args = tool_call["args"]
            tool_call_id = tool_call["id"]
            tool = tools_by_name.get(name)
            if tool is None:
                tool_messages.append(ToolMessage(content=f"Tool {name} does not exist!", tool_call_id=tool_call_id))
                continue
            try:
                # Tools receive the full graph state alongside the model-supplied arguments.
                result = tool.invoke(args | {"graph_state": state})
            except Exception as e:
                # Deliberately broad: the failure is converted into a result for the
                # conversation instead of aborting the graph.
                result = {"error": exception_to_string(e) + f"\nTool: {name}, Args: {args}"}

            # NOTE(review): tools other than run_sql_query/submit_result reply with empty
            # content, so their output and any caught error only live in the artifact —
            # confirm make_tools() exposes no other tools, or send str(result) instead.
            content = ""
            if name == "run_sql_query":
                sql = result.get("sql")
                df = result.get("df")
                # Assign our own query id, overriding any id the tool may have produced.
                query_id = f"{message_index}-{idx}"
                result["query_id"] = query_id
                content = result.get("csv", result.get("error", ""))
                if "csv" in result:
                    content = f"query_id='{query_id}'\n\n{content}"
                query_ids[query_id] = ToolMessage(content=content, tool_call_id=tool_call_id, artifact=result)
            elif name == "submit_result":
                content = str(result)
                # validate_submit has already checked that this id exists and has a df.
                submitted = query_ids[args["query_id"]].artifact
                visualization_prompt = args.get("visualization_prompt", "")
                sql = submitted.get("sql")
                df = submitted["df"]
            tool_messages.append(ToolMessage(content=content, tool_call_id=tool_call_id, artifact=result))

        if is_ready_for_user:
            return {
                "messages": tool_messages,
                "sql": sql,
                "df": df,
                "visualization_prompt": visualization_prompt,
                "ready_for_user": True,
            }
        return {
            "messages": tool_messages,
            "query_ids": query_ids,
            "sql": sql,
            "df": df,
            "visualization_prompt": visualization_prompt,
            "ready_for_user": False,
        }

    def should_continue(state: AgentState) -> Literal["tool_executor", "end"]:
        """Route to the tool executor when the last message is an AIMessage with tool calls."""
        last_message = state["messages"][-1]
        if isinstance(last_message, AIMessage) and last_message.tool_calls:
            return "tool_executor"
        return "end"

    def should_finish(state: AgentState) -> Literal["llm_node", "end"]:
        """End once a valid submit_result has set ``ready_for_user``; otherwise loop back."""
        if state.get("ready_for_user", False):
            return "end"
        return "llm_node"

    graph = StateGraph(AgentState)
    graph.add_node("llm_node", llm_node)
    graph.add_node("tool_executor", tool_executor_node)
    graph.add_edge(START, "llm_node")
    graph.add_conditional_edges("llm_node", should_continue, {"tool_executor": "tool_executor", "end": END})
    graph.add_conditional_edges("tool_executor", should_finish, {"llm_node": "llm_node", "end": END})
    return graph.compile()