# Copyright 2025 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=C0301, W0107, W0107, W0622, R0917
"""Module used to define and interact with agent orchestrators."""
from abc import ABC, abstractmethod
from typing import AsyncGenerator, Generator, Dict, Any, Optional
import uuid

from google.api_core import exceptions
from langchain.agents import (
    AgentExecutor,
    create_react_agent as langchain_create_react_agent
)
from langchain_core.messages import  AIMessage
from langchain_google_vertexai import ChatVertexAI
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
from langgraph.prebuilt import create_react_agent as langgraph_create_react_agent
from vertexai.preview import reasoning_engines # TODO: update this when it becomes agent engine

# LlamaIndex
from llama_index.core import Settings
from llama_index.core.agent import ReActAgent
from llama_index.llms.langchain import LangChainLLM

from app.orchestration.constants import GEMINI_FLASH_20_LATEST
from app.orchestration.enums import OrchestrationFramework
from app.orchestration.tools import (
    get_tools,
    get_llamaindex_tools
)


class BaseAgentManager(ABC):
    """
    Abstract base class for Agent Managers.  Defines the common interface
    for creating and managing agent executors with different orchestration
    frameworks.

    Subclasses must implement ``create_agent_executor`` and ``astream``. They
    may also override ``get_tools`` and ``get_model_obj`` to change how the
    tools and the LLM object are built.
    """

    def __init__(
        self,
        prompt: str,
        industry_type: str,
        location: str,
        orchestration_framework: str,
        agent_engine_resource_id: str,
        model_name: str,
        max_retries: int,
        max_output_tokens: int,
        temperature: float,
        top_p: float,
        top_k: int,
        return_steps: bool,
        verbose: bool
    ):
        """
        Initializes the BaseAgentManager with common configurations, then
        builds the model object, the tools and the agent executor.

        Args:
            prompt: System instructions to give to the agent.
            industry_type: The agent industry type to use. Correlates to tool configs.
            location: The GCP location to run the chat model.
            orchestration_framework: The type of agent framework to use.
            agent_engine_resource_id: The Resource ID of the deployed AE Agent (if using AE).
            model_name: The valid name of the LLM to use for the agent.
            max_retries: Maximum number of times to retry the query on a failure.
            max_output_tokens: Maximum amount of text output from one prompt.
            temperature: Temperature to use for the agent.
            top_p: Top p value. Chooses the words based on a cumulative probability threshold.
            top_k: Top k value. Chooses the top k most likely words
            return_steps: Whether to return the agent's trajectory of intermediate
                steps at the end in addition to the final output.
            verbose: Whether or not run in verbose mode.
        """
        self.prompt = prompt
        self.industry_type = industry_type
        self.location = location
        self.orchestration_framework = orchestration_framework
        self.agent_engine_resource_id = agent_engine_resource_id
        self.model_name = model_name
        self.max_retries = max_retries
        self.max_output_tokens = max_output_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.return_steps = return_steps
        self.verbose = verbose

        # Order matters: create_agent_executor() implementations may read
        # self.model_obj and self.tools, so both are populated first.
        self.model_obj = self.get_model_obj()
        self.tools = self.get_tools()
        self.agent_executor = self.create_agent_executor()


    @abstractmethod
    def create_agent_executor(self):
        """
        Abstract method to create the specific agent executor based on the
        orchestration framework.  This must be implemented by subclasses.

        Returns:
            The initialized agent executor instance.
        """


    def get_tools(self):
        """
        Helper method to retrieve tools based on the industry type and the
        orchestration framework.

        Returns:
            A list of tools for the agent to use, based on the industry type.
        """
        return get_tools(
            self.industry_type,
            self.orchestration_framework
        )


    def get_model_obj(self):
        """
        Helper method to retrieve the model object based on the model name
        and config.

        Returns:
            A ChatAnthropicVertex object when the model name contains
            "claude", otherwise a ChatVertexAI object.

        Raises:
            google.api_core.exceptions.NotFound: The model client raised
                NotFound; re-raised with a "Resource not found." prefix.
            RuntimeError: Any other error raised while building the model.
        """
        try:
            if "claude" in self.model_name:
                return ChatAnthropicVertex(
                    model_name=self.model_name,
                    max_retries=self.max_retries,
                    # max_output_tokens is deliberately not forwarded: it causes
                    # issues with the pydantic model with default None value.
                    # location is pinned instead of using self.location: for now,
                    # claude is only available in us-east5 and europe-west1.
                    location="us-east5",
                    temperature=self.temperature,
                    top_p=self.top_p,
                    top_k=self.top_k,
                    verbose=self.verbose
                )
            return ChatVertexAI(
                model_name=self.model_name,
                location=self.location,
                max_retries=self.max_retries,
                max_output_tokens=self.max_output_tokens,
                temperature=self.temperature,
                top_p=self.top_p,
                top_k=self.top_k,
                verbose=self.verbose
            )
        except exceptions.NotFound as e:
            # Chain explicitly so the original NotFound is kept as __cause__,
            # matching the generic handler below.
            raise exceptions.NotFound(f"Resource not found. {e}") from e
        except Exception as e:
            raise RuntimeError(f"Error encountered initalizing model resource. {e}") from e


    @abstractmethod
    async def astream(
        self,
        input: Dict[str, Any],
    ) -> AsyncGenerator[Dict, Any]:
        """
        Abstract method to asynchronously stream the Agent output.
        This should be implemented by subclasses to handle the specific
        streaming logic of their agent executor.

        Args:
            input: The input payload for the agent; its expected keys are
                defined by each subclass implementation.
        """


class LangChainPrebuiltAgentManager(BaseAgentManager):
    """
    Agent manager that runs a LangChain ReAct agent wrapped in an
    AgentExecutor.
    """
    def __init__(
        self,
        prompt: str,
        industry_type: str,
        location: Optional[str] = "us-central1",
        agent_engine_resource_id: Optional[str] = None,
        model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
        max_retries: Optional[int] = 6,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = 0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        return_steps: Optional[bool] = False,
        verbose: Optional[bool] = True
    ):
        """
        Forwards the configuration to BaseAgentManager with the orchestration
        framework fixed to LANGCHAIN_PREBUILT_AGENT.
        """
        framework = OrchestrationFramework.LANGCHAIN_PREBUILT_AGENT.value
        super().__init__(
            prompt=prompt,
            industry_type=industry_type,
            location=location,
            orchestration_framework=framework,
            agent_engine_resource_id=agent_engine_resource_id,
            model_name=model_name,
            max_retries=max_retries,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            return_steps=return_steps,
            verbose=verbose
        )


    def create_agent_executor(self):
        """
        Builds a LangChain ReAct agent from the prompt, model and tools and
        wraps it in an AgentExecutor.

        Returns:
            The configured AgentExecutor.
        """
        agent = langchain_create_react_agent(
            prompt=self.prompt,
            llm=self.model_obj,
            tools=self.tools,
        )
        executor = AgentExecutor(
            agent=agent,
            tools=self.tools,
            return_intermediate_steps=self.return_steps,
            verbose=self.verbose
        )
        return executor


    async def astream(
        self,
        input: Dict[str, Any],
    ) -> AsyncGenerator[Dict, Any]:
        """Streams the executor output as Agent-Engine-style response dicts.

        Args:
            input: Mapping with a "messages" sequence (the last entry is sent
                as the agent input, the earlier ones as chat history) and a
                "run_id" used as the id of every emitted message.

        Yields:
            One response dictionary per streamed chunk that carries an
            "output" key; chunks without it are skipped.

        Raises:
            RuntimeError: Any error raised while preparing or consuming the
                stream.
        """
        try:
            start_stream = self.agent_executor.astream
            payload = {
                "input": input["messages"][-1],
                "chat_history": input["messages"][:-1],
            }
            async for event in start_stream(payload):
                if "output" not in event:
                    continue

                # Shape the answer like the other Agents (e.g. Agent Engine)
                answer = event["output"]
                message_kwargs = {
                    "content": answer,
                    "type": "ai",
                    "tool_calls": [],
                    "invalid_tool_calls": [],
                    "id": input["run_id"]
                }
                serialized_message = {
                    "lc": 1,
                    "type": "constructor",
                    "id": ["langchain", "schema", "messages", "AIMessage"],
                    "kwargs": message_kwargs
                }
                yield {
                    "agent": {
                        "output": answer,
                        "messages": [serialized_message]
                    }
                }
        except Exception as e:
            raise RuntimeError(f"Unexpected error. {e}") from e


class LangGraphPrebuiltAgentManager(BaseAgentManager):
    """
    Agent manager that runs a prebuilt LangGraph ReAct agent.
    """
    def __init__(
        self,
        prompt: str,
        industry_type: str,
        location: Optional[str] = "us-central1",
        agent_engine_resource_id: Optional[str] = None,
        model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
        max_retries: Optional[int] = 6,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = 0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        return_steps: Optional[bool] = False,
        verbose: Optional[bool] = True
    ):
        """
        Forwards the configuration to BaseAgentManager with the orchestration
        framework fixed to LANGGRAPH_PREBUILT_AGENT.
        """
        framework = OrchestrationFramework.LANGGRAPH_PREBUILT_AGENT.value
        super().__init__(
            prompt=prompt,
            industry_type=industry_type,
            location=location,
            orchestration_framework=framework,
            agent_engine_resource_id=agent_engine_resource_id,
            model_name=model_name,
            max_retries=max_retries,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            return_steps=return_steps,
            verbose=verbose
        )


    def create_agent_executor(self):
        """
        Builds the LangGraph ReAct agent from the prompt, model and tools.
        The verbose flag is passed through as LangGraph's debug option.

        Returns:
            The object returned by LangGraph's create_react_agent.
        """
        graph_agent = langgraph_create_react_agent(
            prompt=self.prompt,
            model=self.model_obj,
            tools=self.tools,
            debug=self.verbose
        )
        return graph_agent


    async def astream(
        self,
        input: Dict[str, Any],
    ) -> AsyncGenerator[Dict, Any]:
        """Streams the graph output as Agent-Engine-style response dicts.

        Args:
            input: Mapping passed unchanged to the graph; it must carry the
                "run_id" used as the id of every emitted message.

        Yields:
            One response dictionary for every streamed state whose last
            message is an AIMessage; other states are skipped.

        Raises:
            RuntimeError: Any error raised while consuming the stream.
        """
        try:
            async for state in self.agent_executor.astream(input, stream_mode="values"):
                latest = state["messages"][-1]
                if not isinstance(latest, AIMessage):
                    continue

                # Shape the answer like the other Agents (e.g. Agent Engine)
                message_kwargs = {
                    "content": latest.content,
                    "type": "ai",
                    "tool_calls": [],
                    "invalid_tool_calls": [],
                    "id": input["run_id"]
                }
                serialized_message = {
                    "lc": 1,
                    "type": "constructor",
                    "id": ["langgraph", "schema", "messages", "AIMessage"],
                    "kwargs": message_kwargs,
                    "usage_metadata": latest.usage_metadata
                }
                yield {"agent": {"messages": [serialized_message]}}
        except Exception as e:
            raise RuntimeError(f"Unexpected error. {e}") from e


class LangChainVertexAIAgentEngineAgentManager(BaseAgentManager):
    """
    Agent manager for LangChain agents backed by Vertex AI Agent Engine,
    either a deployed engine or a locally set-up LangchainAgent.
    """
    def __init__(
        self,
        prompt: str,
        industry_type: str,
        location: Optional[str] = "us-central1",
        agent_engine_resource_id: Optional[str] = None,
        model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
        max_retries: Optional[int] = 6,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = 0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        return_steps: Optional[bool] = False,
        verbose: Optional[bool] = True
    ):
        """
        Forwards the configuration to BaseAgentManager with the orchestration
        framework fixed to LANGCHAIN_VERTEX_AI_AGENT_ENGINE_AGENT.
        """
        framework = OrchestrationFramework.LANGCHAIN_VERTEX_AI_AGENT_ENGINE_AGENT.value
        super().__init__(
            prompt=prompt,
            industry_type=industry_type,
            location=location,
            orchestration_framework=framework,
            agent_engine_resource_id=agent_engine_resource_id,
            model_name=model_name,
            max_retries=max_retries,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            return_steps=return_steps,
            verbose=verbose
        )


    def create_agent_executor(self):
        """
        Returns the agent to query: the deployed Agent Engine when a resource
        id was configured, otherwise a freshly set-up LangchainAgent.
        """
        # A configured resource id means the deployed Agent Engine is used
        if self.agent_engine_resource_id:
            return reasoning_engines.ReasoningEngine(self.agent_engine_resource_id)

        executor_options = {
            "return_intermediate_steps": self.return_steps,
            "verbose": self.verbose
        }
        local_agent = reasoning_engines.LangchainAgent(
            # prompt=self.prompt, # Custom prompt seems to have an issue for some reason
            model=self.model_name,
            tools=self.tools,
            agent_executor_kwargs=executor_options,
            enable_tracing=True
        )
        local_agent.set_up()
        return local_agent


    async def astream(
        self,
        input: Dict[str, Any],
    ) -> AsyncGenerator[Dict, Any]:
        """
        Streams the Agent output obtained from the agent's stream_query.

        Args:
            input: Mapping with a "messages" sequence of dicts (each with
                "type" and "content") and a "run_id" written into every
                streamed message.

        Yields:
            {"agent": chunk} for every chunk returned by stream_query, with
            each message's kwargs "id" replaced by the run id.

        Raises:
            RuntimeError: Any error raised while consuming the stream.
        """
        # Flatten the conversation into one "[type]: content" line per message
        rendered = [f"[{msg['type']}]: {msg['content']}" for msg in input["messages"]]
        content = "\n".join(rendered)
        try:
            for chunk in self.agent_executor.stream_query(input=content):
                restamped = []
                for msg in chunk["messages"]:
                    restamped.append({**msg, "kwargs": {**msg["kwargs"], "id": input["run_id"]}})
                chunk["messages"] = restamped
                yield {"agent": chunk}
        except Exception as e:
            raise RuntimeError(f"Unexpected error. {e}") from e


class LangGraphVertexAIAgentEngineAgentManager(BaseAgentManager):
    """
    Agent manager for LangGraph agents backed by Vertex AI Agent Engine,
    either a deployed engine or a locally set-up LanggraphAgent.
    """
    def __init__(
        self,
        prompt: str,
        industry_type: str,
        location: Optional[str] = "us-central1",
        agent_engine_resource_id: Optional[str] = None,
        model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
        max_retries: Optional[int] = 6,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = 0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        return_steps: Optional[bool] = False,
        verbose: Optional[bool] = True
    ):
        """
        Forwards the configuration to BaseAgentManager with the orchestration
        framework fixed to LANGGRAPH_VERTEX_AI_AGENT_ENGINE_AGENT.
        """
        framework = OrchestrationFramework.LANGGRAPH_VERTEX_AI_AGENT_ENGINE_AGENT.value
        super().__init__(
            prompt=prompt,
            industry_type=industry_type,
            location=location,
            orchestration_framework=framework,
            agent_engine_resource_id=agent_engine_resource_id,
            model_name=model_name,
            max_retries=max_retries,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            return_steps=return_steps,
            verbose=verbose
        )


    def create_agent_executor(self):
        """
        Returns the agent to query: the deployed Agent Engine when a resource
        id was configured, otherwise a freshly set-up LanggraphAgent.
        """
        # A configured resource id means the deployed Agent Engine is used
        if self.agent_engine_resource_id:
            return reasoning_engines.ReasoningEngine(self.agent_engine_resource_id)

        runnable_options = {"prompt": self.prompt, "debug": self.verbose}
        local_agent = reasoning_engines.LanggraphAgent(
            model=self.model_name,
            tools=self.tools,
            runnable_kwargs=runnable_options,
            enable_tracing=True
        )
        local_agent.set_up()
        return local_agent


    async def astream(
        self,
        input: Dict[str, Any],
    ) -> AsyncGenerator[Dict, Any]:
        """
        Streams the Agent output obtained from the agent's stream_query.

        Args:
            input: Mapping passed unchanged to stream_query; it must carry
                the "run_id" written into every streamed message.

        Yields:
            {"agent": chunk} with each message's kwargs "id" replaced by the
            run id. With a deployed Agent Engine only chunks whose last
            message is an AIMessage are yielded; otherwise every chunk is.

        Raises:
            RuntimeError: Any error raised while consuming the stream.
        """
        try:
            chunks = self.agent_executor.stream_query(
                input=input,
                stream_mode="values"
            )
            for chunk in chunks:
                # Replace every message id with the run_id supplied by the server
                restamped = []
                for msg in chunk["messages"]:
                    restamped.append({**msg, "kwargs": {**msg["kwargs"], "id": input["run_id"]}})
                chunk["messages"] = restamped

                # Deployed engines: skip chunks that do not end in an AIMessage
                if self.agent_engine_resource_id and chunk["messages"][-1]["id"][-1] != "AIMessage":
                    continue
                yield {"agent": chunk}
        except Exception as e:
            raise RuntimeError(f"Unexpected error. {e}") from e


class LlamaIndexAgentManager(BaseAgentManager):
    """
    AgentManager subclass for LlamaIndex Agent orchestration.
    """
    def __init__(
        self,
        prompt: str,
        industry_type: str,
        location: Optional[str] = "us-central1",
        agent_engine_resource_id: Optional[str] = None,
        model_name: Optional[str] = GEMINI_FLASH_20_LATEST,
        max_retries: Optional[int] = 6,
        max_output_tokens: Optional[int] = None,
        temperature: Optional[float] = 0,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        return_steps: Optional[bool] = False,
        verbose: Optional[bool] = True
    ):
        """
        Forwards the configuration to BaseAgentManager with the orchestration
        framework fixed to LLAMAINDEX_AGENT, then replaces ``model_obj`` with
        its LlamaIndex ``LangChainLLM`` wrapper.
        """
        super().__init__(
            prompt=prompt,
            industry_type=industry_type,
            location=location,
            orchestration_framework=OrchestrationFramework.LLAMAINDEX_AGENT.value,
            agent_engine_resource_id=agent_engine_resource_id,
            model_name=model_name,
            max_retries=max_retries,
            max_output_tokens=max_output_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            return_steps=return_steps,
            verbose=verbose
        )
        # From here on model_obj is the LlamaIndex wrapper; stream_query()
        # assigns it to Settings.llm.
        self.model_obj = LangChainLLM(self.model_obj)


    def get_tools(self):
        """
        Helper method to retrieve the LlamaIndex tools based on the industry
        type.

        Returns:
            A list of tools for the agent to use, based on the industry type.
        """
        return get_llamaindex_tools(self.industry_type)


    def create_agent_executor(self):
        """
        Creates a LlamaIndex React Agent executor.

        Returns:
            The ReActAgent built from the prompt, model and tools.
        """
        # This runs from the base __init__, before model_obj is replaced by
        # its LangChainLLM wrapper, so the model is wrapped here.
        # setup the index/query llm for Vertex Search
        Settings.llm = LangChainLLM(self.model_obj)
        llamaindex_agent = ReActAgent.from_tools(
            prompt=self.prompt,
            llm=LangChainLLM(self.model_obj),
            tools=self.tools,
            verbose=self.verbose
        )
        return llamaindex_agent


    def get_response_obj(
        self,
        content: str,
        run_id: str
    ) -> Dict:
        """Returns a structured dictionary response object.
        The response object needs to be organized to be
        consistent with other Agents (e.g. Agent Engine)

        Args:
            content: The string response from the LLM.
            run_id: The run_id, stored as the id of the message.

        Returns:
            Structured dictionary response object.
        """
        response_obj = {
            "agent": {
                "messages": [
                    {
                        "lc": 1,
                        "type": "constructor",
                        "id": ["llamaindex", "schema", "messages", "AIMessage"],
                        "kwargs": {
                            "content": content,
                            "type": "ai",
                            "tool_calls": [],
                            "invalid_tool_calls": [],
                            "id": run_id
                        },
                    }
                ]
            }
        }
        return response_obj


    async def astream(
        self,
        input: Dict[str, Any],
    ) -> AsyncGenerator[Dict, Any]:
        """Asynchronously event streams the Agent output.

        When an Agent Engine resource id is configured, the chunks produced
        by the deployed engine's ``stream_query`` are relayed unchanged.
        Otherwise the local agent is queried once and a single response
        object is yielded.

        Args:
            input: Mapping with a "messages" sequence of dicts (each with
                "type" and "content") and, for the local path, a "run_id".

        Yields:
            Dictionaries representing the streamed agent output.

        Raises:
            RuntimeError: An error is encountered during streaming.
        """
        try:
            if self.agent_engine_resource_id:
                llamaindex_agent = reasoning_engines.ReasoningEngine(self.agent_engine_resource_id)
                # Uses the stream_query function below
                response = llamaindex_agent.stream_query(input=input)
                for chunk in response:
                    yield chunk
            else:
                # Convert the messages into a string
                content = "\n".join(f"[{msg['type']}]: {msg['content']}" for msg in input["messages"])
                response = await self.agent_executor.aquery(content)
                yield self.get_response_obj(
                    content=response.response,
                    run_id=input["run_id"]
                )

        except Exception as e:
            raise RuntimeError(f"Unexpected error. {e}") from e

    # NOTE(review): stream_query is a separate synchronous generator rather
    # than an alias of astream; presumably Agent Engine deployments need a
    # synchronous iterable -- confirm against the deployment code.

    def stream_query(
        self,
        input: Dict[str, Any]
    ) -> Generator:
        """Synchronously event streams the Agent output.
        This is the synchronous counterpart of the local path of astream,
        intended for Agent Engine deployments.

        Args:
            input: Mapping with a "messages" sequence of dicts (each with
                "type" and "content"). A "run_id" key, if present, is not
                used: a new run id is generated for the response.

        Yields:
            A single dictionary representing the agent output.

        Raises:
            RuntimeError: An error is encountered during processing.
        """
        # Use the string form so the id matches get_response_obj's `str`
        # contract and stays JSON-serializable (a uuid.UUID object is not).
        run_id = str(uuid.uuid4())

        # If using Agent Engine, explicitly re-set the index/query llm for Vertex Search
        Settings.llm = self.model_obj

        try:
            # Convert the messages into a string
            content = "\n".join(f"[{msg['type']}]: {msg['content']}" for msg in input["messages"])
            response = self.agent_executor.query(content)
            yield self.get_response_obj(
                content=response.response,
                run_id=run_id
            )

        except Exception as e:
            raise RuntimeError(f"Unexpected error. {e}") from e
