# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import asyncio
import base64
import json
import logging
import wave
from typing import Literal, Optional
from urllib.parse import urlencode

import websockets

from pyrit.models import PromptRequestResponse
from pyrit.models.data_type_serializer import data_serializer_factory
from pyrit.models.prompt_request_response import construct_response_from_request
from pyrit.prompt_target import OpenAITarget, limit_requests_per_minute

logger = logging.getLogger(__name__)

RealTimeVoice = Literal["alloy", "echo", "shimmer"]


class RealtimeTarget(OpenAITarget):

    def __init__(
        self,
        *,
        api_version: str = "2024-10-01-preview",
        system_prompt: Optional[str] = "You are a helpful AI assistant",
        voice: Optional[RealTimeVoice] = None,
        existing_convo: Optional[dict] = None,
        **kwargs,
    ) -> None:
        """
        RealtimeTarget class for Azure OpenAI Realtime API.
        Read more at https://learn.microsoft.com/en-us/azure/ai-services/openai/realtime-audio-reference
            and https://platform.openai.com/docs/guides/realtime-websocket

        Args:
            model_name (str, Optional): The name of the model.
            endpoint (str, Optional): The target URL for the OpenAI service.
            api_key (str, Optional): The API key for accessing the Azure OpenAI service.
                Defaults to the OPENAI_REALTIME_API_KEY environment variable.
            headers (str, Optional): Headers of the endpoint (JSON).
            use_aad_auth (bool, Optional): When set to True, user authentication is used
                instead of an API key. DefaultAzureCredential is taken for
                https://cognitiveservices.azure.com/.default . Please run `az login` locally
                to leverage user AuthN.
            api_version (str, Optional): The version of the Azure OpenAI API. Defaults to
                "2024-10-01-preview".
            max_requests_per_minute (int, Optional): Number of requests the target can handle per
                minute before hitting a rate limit. The number of requests sent to the target
                will be capped at the value provided.
            system_prompt (str, Optional): The system prompt to use. Defaults to "You are a helpful AI assistant".
            voice (literal str, Optional): The voice to use. Defaults to None.
                The only voices supported by the Azure OpenAI Realtime API are "alloy", "echo", and "shimmer".
            existing_convo (dict[str, websockets.WebSocketClientProtocol], Optional): Existing conversations,
                keyed by conversation ID. Defaults to None, in which case an empty dict is used.
            httpx_client_kwargs (dict, Optional): Additional kwargs to be passed to the
                httpx.AsyncClient() constructor.
                For example, to specify a 3-minute timeout: httpx_client_kwargs={"timeout": 180}
        """

        super().__init__(api_version=api_version, **kwargs)

        self.system_prompt = system_prompt
        self.voice = voice
        self._existing_conversation = existing_convo if existing_convo is not None else {}
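
    # Illustrative usage sketch (assumes the AZURE_OPENAI_REALTIME_ENDPOINT,
    # OPENAI_REALTIME_API_KEY, and OPENAI_REALTIME_MODEL environment variables are set,
    # per _set_openai_env_configuration_vars below):
    #
    #     target = RealtimeTarget(voice="alloy")
    #     response = await target.send_prompt_async(prompt_request=prompt_request)
    #     await target.cleanup_target()
    #
    # where prompt_request is a PromptRequestResponse whose single piece has data type
    # "text" or "audio_path" (see _validate_request).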

    def _set_openai_env_configuration_vars(self):
        self.model_name_environment_variable = "OPENAI_REALTIME_MODEL"
        self.endpoint_environment_variable = "AZURE_OPENAI_REALTIME_ENDPOINT"
        self.api_key_environment_variable = "OPENAI_REALTIME_API_KEY"

    async def connect(self):
        """
        Connects to the Realtime API target using websockets.
        Returns the WebSocket connection.
        """

        logger.info(f"Connecting to WebSocket: {self._endpoint}")

        query_params = {
            "deployment": self._model_name,
            "OpenAI-Beta": "realtime=v1",
        }

        self._add_auth_param_to_query_params(query_params)

        if self._api_version is not None:
            query_params["api-version"] = self._api_version

        url = f"{self._endpoint}?{urlencode(query_params)}"

        websocket = await websockets.connect(url)
        logger.info("Successfully connected to AzureOpenAI Realtime API")
        return websocket
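
    # Assuming an Azure endpoint of the form wss://<resource>.openai.azure.com/openai/realtime,
    # the URL built above looks like:
    #     wss://<resource>.openai.azure.com/openai/realtime?deployment=<model>
    #         &OpenAI-Beta=realtime%3Dv1&api-key=<key>&api-version=2024-10-01-preview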

    def _add_auth_param_to_query_params(self, query_params: dict) -> None:
        """
        Adds the authentication parameter to the query parameters. This is how the
        Realtime API works; it does not use headers for auth.

        Args:
            query_params (dict): The query parameters.
        """
        if self._api_key:
            query_params["api-key"] = self._api_key

        if self._azure_auth:
            query_params["access_token"] = self._azure_auth.refresh_token()

    def _set_system_prompt_and_config_vars(self) -> dict:
        """
        Builds the session configuration dictionary sent to the server in the session.update event.
        """
        session_config = {
            "modalities": ["audio", "text"],
            "instructions": self.system_prompt,
            "input_audio_format": "pcm16",
            "output_audio_format": "pcm16",
            "turn_detection": None,
        }

        if self.voice:
            session_config["voice"] = self.voice

        return session_config
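
    # With the default system prompt and voice="alloy", send_config below would emit:
    #     {"type": "session.update",
    #      "session": {"modalities": ["audio", "text"],
    #                  "instructions": "You are a helpful AI assistant",
    #                  "input_audio_format": "pcm16", "output_audio_format": "pcm16",
    #                  "turn_detection": None, "voice": "alloy"}}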

    async def send_event(self, event: dict, conversation_id: str):
        """
        Sends an event to the WebSocket server.

        Args:
            event (dict): Event to send in dictionary format.
            conversation_id (str): Conversation ID
        """
        websocket = self._existing_conversation.get(conversation_id)

        if websocket is None:
            logger.error("WebSocket connection is not established")
            raise RuntimeError("WebSocket connection is not established")
        await websocket.send(json.dumps(event))
        logger.debug(f"Event sent - type: {event['type']}")

    async def send_config(self, conversation_id: str):
        """
        Sends the session configuration to the WebSocket server.

        Args:
            conversation_id (str): Conversation ID
        """

        config_variables = self._set_system_prompt_and_config_vars()

        await self.send_event(
            event={"type": "session.update", "session": config_variables}, conversation_id=conversation_id
        )
        logger.info("Session set up")

    @limit_requests_per_minute
    async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
        """
        Sends a text or audio prompt to the target over the Realtime WebSocket.

        Args:
            prompt_request (PromptRequestResponse): The prompt request. Its single request piece must
                have data type "text" or "audio_path".

        Returns:
            PromptRequestResponse: The response, containing a text transcript piece and an audio_path piece.
        """
        convo_id = prompt_request.request_pieces[0].conversation_id
        if convo_id not in self._existing_conversation:
            websocket = await self.connect()
            self._existing_conversation[convo_id] = websocket

            self.set_system_prompt(
                system_prompt=self.system_prompt,
                conversation_id=convo_id,
                orchestrator_identifier=self.get_identifier(),
            )

        websocket = self._existing_conversation[convo_id]

        self._validate_request(prompt_request=prompt_request)

        await self.send_config(conversation_id=convo_id)
        request = prompt_request.request_pieces[0]
        response_type = request.converted_value_data_type

        # Order of messages sent varies based on the data format of the prompt
        if response_type == "audio_path":
            output_audio_path, events = await self.send_audio_async(
                filename=request.converted_value, conversation_id=convo_id
            )

        elif response_type == "text":
            output_audio_path, events = await self.send_text_async(
                text=request.converted_value, conversation_id=convo_id
            )

        text_response_piece = construct_response_from_request(
            request=request, response_text_pieces=[events[1]], response_type="text"
        ).request_pieces[0]

        audio_response_piece = construct_response_from_request(
            request=request, response_text_pieces=[output_audio_path], response_type="audio_path"
        ).request_pieces[0]

        response_entry = PromptRequestResponse(request_pieces=[text_response_piece, audio_response_piece])
        return response_entry
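
    # The returned response pairs the transcript with the saved reply audio:
    #     response_entry.request_pieces[0] -> text transcript (events[1] from receive_events)
    #     response_entry.request_pieces[1] -> path to the saved WAV file (from save_audio)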

    async def save_audio(
        self,
        audio_bytes: bytes,
        num_channels: int = 1,
        sample_width: int = 2,
        sample_rate: int = 16000,
        output_filename: Optional[str] = None,
    ) -> str:
        """
        Saves audio bytes to a WAV file.

        Args:
            audio_bytes (bytes): Audio bytes to save.
            num_channels (int): Number of audio channels. Defaults to 1 for the PCM16 format.
            sample_width (int): Sample width in bytes. Defaults to 2 for the PCM16 format.
            sample_rate (int): Sample rate in Hz. Defaults to 16000 Hz for the PCM16 format.
            output_filename (str): Output filename. If None, a UUID filename will be used.

        Returns:
            str: The path to the saved audio file.
        """
        data = data_serializer_factory(category="prompt-memory-entries", data_type="audio_path")

        await data.save_formatted_audio(
            data=audio_bytes,
            output_filename=output_filename,
            num_channels=num_channels,
            sample_width=sample_width,
            sample_rate=sample_rate,
        )

        return data.value
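
    # Sizing check for the PCM16 defaults above: mono 16-bit audio at 16000 Hz is
    # 16000 samples/s * 2 bytes/sample * 1 channel = 32000 bytes per second of audio.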

    async def cleanup_target(self):
        """
        Disconnects from the WebSocket server, closing all existing conversations.
        """
        for conversation_id, websocket in self._existing_conversation.items():
            if websocket:
                await websocket.close()
                logger.info(f"Disconnected from {self._endpoint} with conversation ID: {conversation_id}")
        self._existing_conversation = {}

    async def cleanup_conversation(self, conversation_id: str):
        """
        Disconnects from the WebSocket server for a specific conversation
        """
        websocket = self._existing_conversation.get(conversation_id)
        if websocket:
            await websocket.close()
            logger.info(f"Disconnected from {self._endpoint} with conversation ID: {conversation_id}")
            del self._existing_conversation[conversation_id]

    async def send_response_create(self, conversation_id: str):
        """
        Sends response.create message to the WebSocket server.
        """
        await self.send_event(event={"type": "response.create"}, conversation_id=conversation_id)

    async def receive_events(self, conversation_id: str) -> list:
        """
        Continuously receive events from the WebSocket server.
        Args:
            conversation_id: conversation ID
        """
        websocket = self._existing_conversation[conversation_id]
        if websocket is None:  # change this to existing_conversation.websocket
            logger.error("WebSocket connection is not established")
            raise Exception("WebSocket connection is not established")

        audio_transcript = None
        audio_buffer = b""
        conversation_messages = []
        try:
            async for message in websocket:
                event = json.loads(message)
                msg_response_type = event.get("type")
                if msg_response_type:
                    if msg_response_type == "response.done":
                        logger.debug(f"event is: {json.dumps(event, indent=2)}")
                        audio_transcript = event["response"]["output"][0]["content"][0]["transcript"]
                        conversation_messages.append(audio_transcript)
                        break
                    elif msg_response_type == "error":
                        logger.error(f"Error, event is: {json.dumps(event, indent=2)}")
                        break
                    elif msg_response_type == "response.audio.delta":
                        # Append audio data to buffer
                        audio_data = base64.b64decode(event["delta"])
                        audio_buffer += audio_data
                        logger.debug("Audio data appended to buffer")
                    elif msg_response_type == "response.audio.done":
                        logger.debug(f"event is: {json.dumps(event, indent=2)}")
                        conversation_messages.append(audio_buffer)
                    else:
                        logger.debug(f"event is: {json.dumps(event, indent=2)}")

        except websockets.ConnectionClosed as e:
            logger.error(f"WebSocket connection closed: {e}")
        except Exception as e:
            logger.error(f"An unexpected error occurred: {e}")
        return conversation_messages
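
    # Typical event flow handled above for one turn (illustrative):
    #     response.audio.delta  (repeated; base64 PCM16 chunks appended to audio_buffer)
    #     response.audio.done   (accumulated audio bytes appended to conversation_messages)
    #     response.done         (transcript appended to conversation_messages; loop exits)
    # so on success conversation_messages is [audio_bytes, transcript]; callers read the
    # audio at index 0 and the transcript at index 1.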

    async def send_text_async(self, text: str, conversation_id: str):
        """
        Sends text prompt to the WebSocket server.
        Args:
            text: prompt to send.
            conversation_id: conversation ID
        """
        await self.send_response_create(conversation_id=conversation_id)

        # Listen for responses
        receive_tasks = asyncio.create_task(self.receive_events(conversation_id=conversation_id))

        logger.info(f"Sending text message: {text}")
        event = {
            "type": "conversation.item.create",
            "item": {"type": "message", "role": "user", "content": [{"type": "input_text", "text": text}]},
        }
        await self.send_event(event=event, conversation_id=conversation_id)

        events = await receive_tasks  # Wait for all responses to be received

        output_audio_path = await self.save_audio(events[0])
        return output_audio_path, events

    async def send_audio_async(self, filename: str, conversation_id: str):
        """
        Send an audio message to the WebSocket server.

        Args:
            filename (str): The path to the audio file.
        """
        with wave.open(filename, "rb") as wav_file:

            # Read WAV parameters
            num_channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()  # Should be 2 bytes for PCM16
            frame_rate = wav_file.getframerate()
            num_frames = wav_file.getnframes()

            audio_content = wav_file.readframes(num_frames)

        receive_task = asyncio.create_task(self.receive_events(conversation_id=conversation_id))

        try:
            audio_base64 = base64.b64encode(audio_content).decode("utf-8")

            event = {"type": "input_audio_buffer.append", "audio": audio_base64}

            await self.send_event(event=event, conversation_id=conversation_id)

        except Exception as e:
            logger.error(f"Error sending audio: {e}")
            raise
        event = {"type": "input_audio_buffer.commit"}
        await asyncio.sleep(0.1)
        await self.send_event(event=event, conversation_id=conversation_id)
        await self.send_response_create(conversation_id=conversation_id)  # Sends response.create message

        responses = await receive_task
        output_audio_path = await self.save_audio(responses[0], num_channels, sample_width, frame_rate)
        return output_audio_path, responses
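
    # Message sequence for one audio turn, as sent above:
    #     input_audio_buffer.append -> input_audio_buffer.commit -> response.create
    # The model's reply is collected by the receive_events task started before sending.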

    def _validate_request(self, *, prompt_request: PromptRequestResponse) -> None:
        """Validates the structure and content of a prompt request for compatibility of this target.

        Args:
            prompt_request (PromptRequestResponse): The prompt request response object.

        Raises:
            ValueError: If the number of request pieces is not exactly one.
            ValueError: If the request piece has a data type other than 'text' or 'audio_path'.
        """

        # Check the number of request pieces
        if len(prompt_request.request_pieces) != 1:
            raise ValueError("This target only supports one request piece.")

        if prompt_request.request_pieces[0].converted_value_data_type not in ["text", "audio_path"]:
            raise ValueError("This target only supports text and audio_path prompt input.")

    def is_json_response_supported(self) -> bool:
        """Indicates that this target supports JSON response format."""
        return False
