from collections.abc import Sequence
from typing import Annotated, Any, Literal

import pandas as pd
from duckdb import DuckDBPyConnection
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool, tool
from langchain_openai import ChatOpenAI
from langgraph.constants import END, START
from langgraph.graph import add_messages
from langgraph.graph.state import CompiledStateGraph, StateGraph
from langgraph.prebuilt import InjectedState
from typing_extensions import TypedDict

from databao.configs.llm import LLMConfig
from databao.core import ExecutionResult
from databao.duckdb.react_tools import execute_duckdb_sql
from databao.executors.frontend.text_frontend import dataframe_to_markdown
from databao.executors.lighthouse.utils import exception_to_string


class AgentState(TypedDict):
    messages: Annotated[list[BaseMessage], add_messages]
    query_ids: dict[str, ToolMessage]
    sql: str | None
    df: pd.DataFrame | None
    visualization_prompt: str | None
    ready_for_user: bool
    limit_max_rows: int | None


def get_query_ids_mapping(messages: list[BaseMessage]) -> dict[str, ToolMessage]:
    query_ids = {}
    for message in messages:
        if isinstance(message, ToolMessage) and isinstance(message.artifact, dict) and "query_id" in message.artifact:
            query_ids[message.artifact["query_id"]] = message
    return query_ids


class ExecuteSubmit:
    """Simple graph with two tools: run_sql_query and submit_result.
    All context must be in the SystemMessage."""

    MAX_TOOL_ROWS = 12
    """Max number of rows to return in SQL tool calls."""

    def __init__(self, connection: DuckDBPyConnection):
        self._connection = connection

    def init_state(self, messages: list[BaseMessage], *, limit_max_rows: int | None = None) -> AgentState:
        return AgentState(
            messages=messages,
            query_ids=get_query_ids_mapping(messages),
            sql=None,
            df=None,
            visualization_prompt=None,
            ready_for_user=False,
            limit_max_rows=limit_max_rows,
        )

    def get_result(self, state: AgentState) -> ExecutionResult:
        last_ai_message = None
        for m in reversed(state["messages"]):
            if isinstance(m, AIMessage):
                last_ai_message = m
                break
        if last_ai_message is None:
            raise RuntimeError("No AI message found in message log")
        if len(last_ai_message.tool_calls) == 0:
            # Sometimes models don't call the submit_result tool, but we still want to return some dataframe.
            sql = state.get("sql", "")
            df = state.get("df")  # Latest df result (usually from run_sql_query)
            visualization_prompt = state.get("visualization_prompt")
            result = ExecutionResult(
                text=last_ai_message.text(),
                df=df,
                code=sql,
                meta={
                    "visualization_prompt": visualization_prompt,
                    "messages": state["messages"],
                    "submit_called": False,
                },
            )
        elif len(last_ai_message.tool_calls) > 1:
            raise RuntimeError("Expected exactly one tool call in AI message")
        elif last_ai_message.tool_calls[0]["name"] != "submit_result":
            raise RuntimeError(
                f"Expected submit_result tool call in AI message, got {last_ai_message.tool_calls[0]['name']}"
            )
        else:
            sql = state.get("sql", "")
            df = state.get("df")
            tool_call = last_ai_message.tool_calls[0]
            text = tool_call["args"]["result_description"]
            visualization_prompt = state.get("visualization_prompt", "")
            result = ExecutionResult(
                text=text,
                df=df,
                code=sql,
                meta={
                    "visualization_prompt": visualization_prompt,
                    "messages": state["messages"],
                    "submit_called": True,
                },
            )
        return result

    def make_tools(self) -> list[BaseTool]:
        @tool(parse_docstring=True)
        def run_sql_query(sql: str, graph_state: Annotated[AgentState, InjectedState]) -> dict[str, Any]:
            """
            Run a SELECT SQL query in the database. Returns the first 12 rows in csv format.

            Args:
                sql: SQL query
            """
            try:
                # TODO use ToolRuntime in LangChain v1.0
                limit = graph_state["limit_max_rows"]
                df = execute_duckdb_sql(sql, self._connection, limit=limit)
                df_csv = df.head(self.MAX_TOOL_ROWS).to_csv(index=False)
                df_markdown = dataframe_to_markdown(df.head(self.MAX_TOOL_ROWS), index=False)
                if len(df) > self.MAX_TOOL_ROWS:
                    df_csv += f"\nResult is truncated from {len(df)} to {self.MAX_TOOL_ROWS} rows."
                    df_markdown += f"\nResult is truncated from {len(df)} to {self.MAX_TOOL_ROWS} rows."
                return {"df": df, "sql": sql, "csv": df_csv, "markdown": df_markdown}
            except Exception as e:
                return {"error": exception_to_string(e)}

        @tool(parse_docstring=True)
        def submit_result(
            query_id: str,
            result_description: str,
            visualization_prompt: str,
        ) -> str:
            """
            Call this tool with the ID of the query you want to submit to the user.
            This will return control to the user and must always be the last tool call.
            The user will see the full query result, not just the first 12 rows. Returns a confirmation message.

            Args:
                query_id: The ID of the query to submit (query_ids are automatically generated when you run queries).
                result_description: A comment to a final result. This will be included in the final result.
                visualization_prompt: Optional visualization prompt. If not empty, a Vega-Lite visualization agent
                    will be asked to plot the submitted query data according to instructions in the prompt.
                    The instructions should be short and simple.
            """
            return f"Query {query_id} submitted successfully. Your response is now visible to the user."

        tools = [run_sql_query, submit_result]
        return tools

    def compile(self, model_config: LLMConfig) -> CompiledStateGraph[Any]:
        tools = self.make_tools()
        llm_model = model_config.new_chat_model()

        model_with_tools = self._model_bind_tools(llm_model, tools)

        def llm_node(state: AgentState) -> dict[str, Any]:
            messages = state["messages"]
            response = self._chat(messages, model_config, model_with_tools)
            return {"messages": [response[-1]]}

        def tool_executor_node(state: AgentState) -> dict[str, Any]:
            last_message = state["messages"][-1]
            tool_messages = []
            assert isinstance(last_message, AIMessage)

            tool_calls = last_message.tool_calls

            is_ready_for_user = any(tc["name"] == "submit_result" for tc in tool_calls)
            if is_ready_for_user:
                if len(tool_calls) > 1:
                    tool_messages = [
                        ToolMessage("submit_result must be the only tool call.", tool_call_id=tool_call["id"])
                        for tool_call in tool_calls
                    ]
                    return {"messages": tool_messages, "ready_for_user": False}
                else:
                    tool_call = tool_calls[0]

                    if "query_ids" not in state or len(state["query_ids"]) == 0:
                        tool_messages = [
                            ToolMessage("No queries have been executed yet.", tool_call_id=tool_call["id"])
                        ]
                        return {"messages": tool_messages, "ready_for_user": False}

                    query_id = tool_call["args"]["query_id"]
                    if query_id not in state["query_ids"]:
                        available_ids = ", ".join(state["query_ids"].keys())
                        tool_messages = [
                            ToolMessage(
                                f"Query ID {query_id} not found. Available query IDs: {available_ids}",
                                tool_call_id=tool_call["id"],
                            )
                        ]
                        return {"messages": tool_messages, "ready_for_user": False}

                    target_tool_message = state["query_ids"][query_id]
                    if target_tool_message.artifact is None or "df" not in target_tool_message.artifact:
                        tool_messages = [
                            ToolMessage(f"Query {query_id} does not have a valid result.", tool_call_id=tool_call["id"])
                        ]
                        return {"messages": tool_messages, "ready_for_user": False}

            query_ids = dict(state.get("query_ids", {}))
            sql = state.get("sql")
            df = state.get("df")
            visualization_prompt = state.get("visualization_prompt", "")

            message_index = len(state["messages"]) - 1

            for idx, tool_call in enumerate(tool_calls):
                name = tool_call["name"]
                args = tool_call["args"]
                tool_call_id = tool_call["id"]
                # Find the tool by name
                tool = next((t for t in tools if t.name == name), None)
                if tool is None:
                    tool_messages.append(ToolMessage(content=f"Tool {name} does not exist!", tool_call_id=tool_call_id))
                    continue

                try:
                    result = tool.invoke(args | {"graph_state": state})
                except Exception as e:
                    result = {"error": exception_to_string(e) + f"\nTool: {name}, Args: {args}"}

                content = ""
                if name == "run_sql_query":
                    sql = result.get("sql")
                    df = result.get("df")
                    # Generate query_id using message index and tool call index
                    query_id = f"{message_index}-{idx}"
                    # Override the query_id in the result
                    result["query_id"] = query_id
                    content = result.get("csv", result.get("error", ""))
                    if "csv" in result:
                        content = f"query_id='{query_id}'\n\n{content}"
                    if query_id:
                        query_ids[query_id] = ToolMessage(
                            content=content,
                            tool_call_id=tool_call_id,
                            artifact=result,
                        )
                elif name == "submit_result":
                    content = str(result)
                    query_id = tool_call["args"]["query_id"]
                    visualization_prompt = tool_call["args"].get("visualization_prompt", "")
                    sql = state["query_ids"][query_id].artifact["sql"]
                    df = state["query_ids"][query_id].artifact["df"]
                tool_messages.append(ToolMessage(content=content, tool_call_id=tool_call_id, artifact=result))
                if name == "submit_result":
                    return {
                        "messages": tool_messages,
                        "sql": sql,
                        "df": df,
                        "visualization_prompt": visualization_prompt,
                        "ready_for_user": True,
                    }
            return {
                "messages": tool_messages,
                "query_ids": query_ids,
                "sql": sql,
                "df": df,
                "visualization_prompt": visualization_prompt,
                "ready_for_user": False,
            }

        def should_continue(state: AgentState) -> Literal["tool_executor", "end"]:
            # Check if there are tool calls in the last message
            last_message = state["messages"][-1]
            if isinstance(last_message, AIMessage) and last_message.tool_calls:
                return "tool_executor"
            return "end"

        def should_finish(state: AgentState) -> Literal["llm_node", "end"]:
            # Check if we just executed submit_result - if so, end the conversation
            if state.get("ready_for_user", False):
                return "end"
            return "llm_node"

        graph = StateGraph(AgentState)
        graph.add_node("llm_node", llm_node)
        graph.add_node("tool_executor", tool_executor_node)

        graph.add_edge(START, "llm_node")
        graph.add_conditional_edges("llm_node", should_continue, {"tool_executor": "tool_executor", "end": END})
        graph.add_conditional_edges("tool_executor", should_finish, {"llm_node": "llm_node", "end": END})
        return graph.compile()

    @staticmethod
    def _model_bind_tools(
        model: BaseChatModel, tools: Sequence[BaseTool], **kwargs: Any
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        if isinstance(model, ChatOpenAI):
            return model.bind_tools(tools, strict=True, **kwargs)
        else:
            return model.bind_tools(tools, **kwargs)

    @staticmethod
    def _chat(
        messages: list[BaseMessage],
        config: LLMConfig,
        model: Runnable[list[BaseMessage], Any] | None = None,
    ) -> list[BaseMessage]:
        if model is None:
            model = config.new_chat_model()
        messages = ExecuteSubmit._apply_system_prompt_caching(config, messages)
        response: AIMessage = ExecuteSubmit._call_model(model, messages)
        return [*messages, response]

    @staticmethod
    def _is_anthropic_model(config: LLMConfig) -> bool:
        """Check if the model is an Anthropic model based on the config name."""
        return "claude" in config.name.lower()

    @staticmethod
    def _apply_system_prompt_caching(config: LLMConfig, messages: list[BaseMessage]) -> list[BaseMessage]:
        """Apply system prompt caching for Anthropic models."""
        if not (config.cache_system_prompt and ExecuteSubmit._is_anthropic_model(config)):
            return messages
        # Assume only the first message can be a system prompt.
        assert all(m.type != "system" for m in messages[1:])
        if messages[0].type == "system":
            messages = [ExecuteSubmit._set_message_cache_breakpoint(config, messages[0]), *messages[1:]]
        return messages

    @staticmethod
    def _set_message_cache_breakpoint(config: LLMConfig, message: BaseMessage) -> BaseMessage:
        """Enable prompt caching for this message (for Anthropic models).

        If you have a list of messages, set a breakpoint only on the last message to automatically
        cache all previous messages.

        See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
        > Prompt caching references the entire prompt - tools, system, and messages (in that order) up to and including
            the block designated with cache_control.
        """
        if not ExecuteSubmit._is_anthropic_model(config):
            return message
        new_content: list[dict[str, Any] | str]
        match message.content:
            case str() | dict():
                new_content = [ExecuteSubmit._set_anthropic_cache_breakpoint(message.content)]
            case list():
                # Set checkpoint only for the last message
                new_content = message.content.copy()
                new_content[-1] = ExecuteSubmit._set_anthropic_cache_breakpoint(new_content[-1])
        return message.model_copy(update={"content": new_content})

    @staticmethod
    def _set_anthropic_cache_breakpoint(content: str | dict[str, Any]) -> dict[str, Any]:
        if isinstance(content, str):
            return {"type": "text", "text": content, "cache_control": {"type": "ephemeral"}}
        elif isinstance(content, dict):
            d = content.copy()
            d["cache_control"] = {"type": "ephemeral"}
            return d
        else:
            raise ValueError(f"Unknown content type: {type(content)}")

    @staticmethod
    def _call_model(model: Runnable[list[BaseMessage], Any], messages: list[BaseMessage]) -> Any:
        return model.with_retry(wait_exponential_jitter=True, stop_after_attempt=3).invoke(messages)
