import asyncio
import json
import logging
import os
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from google.adk.agents.llm_agent import LlmAgent
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.runners import Runner
from google.adk.sessions import InMemorySessionService
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset, StdioServerParameters
from google.genai import types
from starlette.websockets import WebSocketDisconnect

# Load settings from a .env file into the environment at import time.
# NOTE(review): presumably this supplies the model API credentials — confirm.
load_dotenv()

# Application name passed to both the ADK Runner and the session service.
APP_NAME = "ADK MCP example"
# Module-level service instances shared by every websocket connection handled
# by this process (in-memory implementations, per the class names).
session_service = InMemorySessionService()
artifacts_service = InMemoryArtifactService()


async def get_tools_async(server_params):
    """Fetch the tools exposed by an MCP server.

    Args:
        server_params: Connection parameters forwarded to
            ``MCPToolset.from_server``.

    Returns:
        A ``(tools, exit_stack)`` pair. The caller owns ``exit_stack`` and
        must close it to release the server connection.
    """
    # The MCP server connection has to stay open while the tools are in use,
    # so the exit stack is handed back to the caller for cleanup before exit.
    mcp_tools, cleanup_stack = await MCPToolset.from_server(
        connection_params=server_params
    )
    return mcp_tools, cleanup_stack


async def get_agent_async(server_params):
    """Create an ADK agent equipped with the tools of an MCP server.

    Args:
        server_params: Connection parameters for the MCP server, forwarded to
            ``get_tools_async``.

    Returns:
        A ``(root_agent, exit_stack)`` pair. The caller owns ``exit_stack``
        and must close it once the agent is no longer needed.

    Raises:
        Exception: Whatever ``LlmAgent`` raises during construction; the MCP
            connection is closed before the error propagates.
    """
    tools, exit_stack = await get_tools_async(server_params)
    try:
        root_agent = LlmAgent(
            model="gemini-2.5-pro-preview-03-25",
            name="ai_assistant",
            instruction="You're a helpful assistant. Use tools to get information to answer user questions, please format your answer in markdown format.",
            tools=tools,
        )
    except Exception:
        # The caller never receives exit_stack on this path, so release the
        # MCP server connection here instead of leaking it.
        await exit_stack.aclose()
        raise
    return root_agent, exit_stack


# Stdio launch parameters for the cocktail MCP server script; the websocket
# endpoint passes these to the agent task. The script path is relative, so it
# presumably resolves against the process working directory — confirm.
ct_server_params = StdioServerParameters(
    command="python",
    args=["./mcp_server/cocktail.py"],
)


async def run_agent(server_params, session_id, question):
    """Run one user question through a freshly built agent.

    A new agent (and MCP server connection) is created for every call and
    torn down again before returning.

    Args:
        server_params: Connection parameters for the MCP server.
        session_id: Session identifier; also used as the user id for the run.
        question: The user's message text.

    Returns:
        A list with the text of the first part of every model event that
        carried text, in arrival order. Empty if the model produced no text.
    """
    print("[user]: ", question)
    content = types.Content(role="user", parts=[types.Part(text=question)])
    root_agent, exit_stack = await get_agent_async(server_params)

    response = []
    try:
        runner = Runner(
            app_name=APP_NAME,
            agent=root_agent,
            artifact_service=artifacts_service,
            session_service=session_service,
        )
        events_async = runner.run_async(
            session_id=session_id, user_id=session_id, new_message=content
        )

        async for event in events_async:
            event_content = event.content
            # Some events carry no content or no parts; skip them instead of
            # failing on the attribute access / index below.
            if event_content is None or not event_content.parts:
                continue
            text = event_content.parts[0].text
            if event_content.role == "model" and text:
                print("[agent]:", text)
                response.append(text)
    finally:
        # Always release the MCP server connection, even when the run fails.
        await exit_stack.aclose()
    return response


async def run_adk_agent_async(websocket, server_params, session_id):
    """Relay messages between a websocket client and the agent.

    Loops until the client disconnects or an error occurs: each text frame
    received is run through ``run_agent`` and any reply is sent back as a
    JSON object with a single ``"message"`` key.

    Args:
        websocket: The accepted client websocket.
        server_params: MCP server connection parameters for ``run_agent``.
        session_id: Session identifier used for the agent run and in logs.
    """
    try:
        logging.info(f"Agent task started for session {session_id}")
        while True:
            incoming = await websocket.receive_text()
            reply_parts = await run_agent(server_params, session_id, incoming)
            if reply_parts:
                # Join all reply fragments into one message for the client.
                payload = json.dumps({"message": "\n".join(reply_parts)})
                await websocket.send_text(payload)
                await asyncio.sleep(0)
    except WebSocketDisconnect:
        # The client went away; this is the normal way out of the loop.
        logging.info(f"Client {session_id} disconnected.")
    except Exception as e:
        # Any other failure in the agent logic is logged with its traceback.
        logging.error(
            f"Error in agent task for session {session_id}: {e}", exc_info=True
        )
    finally:
        logging.info(f"Agent task ending for session {session_id}")


# FastAPI web app

app = FastAPI()

# Directory holding the front-end assets (a relative path); its contents are
# served under the /static URL prefix.
STATIC_DIR = Path("static")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")


@app.get("/")
async def root():
    """Return the front-end entry page, ``index.html``, from the static dir."""
    index_page = str(STATIC_DIR / "index.html")
    return FileResponse(index_page)


@app.websocket("/ws/{session_id}")
async def websocket_endpoint(websocket: WebSocket, session_id: int):
    """Client websocket endpoint.

    Accepts the connection, creates an agent session keyed by the session id
    from the URL, and serves the client until the agent loop ends.

    Args:
        websocket: The incoming client websocket.
        session_id: Integer id from the URL path; its string form is used as
            both the session id and the user id.
    """
    # Wait for client connection
    await websocket.accept()
    print(f"Client #{session_id} connected")

    # Start agent session. The returned session object is not needed here;
    # later lookups go through session_service using the same ids.
    # NOTE(review): if the installed ADK release makes create_session a
    # coroutine, this call must be awaited — confirm against the ADK version.
    session_id = str(session_id)
    session_service.create_session(
        app_name=APP_NAME, user_id=session_id, session_id=session_id, state={}
    )

    # Serve the client until it disconnects; there is only one coroutine to
    # wait on, so it is awaited directly rather than wrapped in a task.
    await run_adk_agent_async(websocket, ct_server_params, session_id)

    # Disconnected
    print(f"Client #{session_id} disconnected")
