# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
import pathlib
import uuid
from textwrap import dedent

from pyrit.common.path import DATASETS_PATH
from pyrit.exceptions import (
    InvalidJsonException,
    pyrit_json_retry,
    remove_markdown_json,
)
from pyrit.models import (
    PromptDataType,
    PromptRequestPiece,
    PromptRequestResponse,
    SeedPrompt,
)
from pyrit.prompt_converter import ConverterResult, PromptConverter
from pyrit.prompt_target import PromptChatTarget

logger = logging.getLogger(__name__)


class VariationConverter(PromptConverter):
    """
    Converter that asks a chat target to produce a variation of a text prompt.

    The system prompt is rendered once, at construction time, from a ``SeedPrompt``
    template. Each call to ``convert_async`` starts a new conversation with the
    converter target and returns the first variation found in its JSON response.
    """

    def __init__(self, *, converter_target: PromptChatTarget, prompt_template: SeedPrompt = None):
        """
        Parameters:
            converter_target: chat target that is asked to generate the variation
            prompt_template: template rendered into the system prompt; when it is not
                provided, ``variation_converter.yaml`` from the ``prompt_converters``
                directory under ``DATASETS_PATH`` is loaded instead
        """
        self.converter_target = converter_target

        # Fall back to the default template when the caller does not provide one.
        if not prompt_template:
            prompt_template = SeedPrompt.from_yaml_file(
                pathlib.Path(DATASETS_PATH) / "prompt_converters" / "variation_converter.yaml"
            )

        # Only one variation is requested; send_variation_prompt_async returns just the first element.
        self.number_variations = 1

        self.system_prompt = str(prompt_template.render_template_value(number_iterations=str(self.number_variations)))

    async def convert_async(self, *, prompt: str, input_type: PromptDataType = "text") -> ConverterResult:
        """
        Generates a variation of the input prompt using the converter target.
        Parameters:
            prompt: the seed prompt to create a variation of
            input_type: data type of the prompt; only "text" is supported
        Return:
            ConverterResult whose output_text is the variation returned by the converter target
        Raises:
            ValueError: if input_type is not supported
        """
        if not self.input_supported(input_type):
            raise ValueError("Input type not supported")

        # Every conversion runs in its own conversation on the converter target.
        conversation_id = str(uuid.uuid4())

        self.converter_target.set_system_prompt(
            system_prompt=self.system_prompt,
            conversation_id=conversation_id,
            orchestrator_identifier=None,
        )

        # Adjacent string literals are concatenated with nothing in between, so the
        # line breaks must be explicit; otherwise the instruction, the tags and the
        # seed prompt run together into a single unseparated line.
        request_text = dedent(
            f"Create {self.number_variations} variation of the seed prompt given by the user between the "
            "begin and end tags\n"
            "=== begin ===\n"
            f"{prompt}\n"
            "=== end ==="
        )

        request = PromptRequestResponse(
            [
                PromptRequestPiece(
                    role="user",
                    original_value=request_text,
                    converted_value=request_text,
                    conversation_id=conversation_id,
                    sequence=1,
                    prompt_target_identifier=self.converter_target.get_identifier(),
                    original_value_data_type=input_type,
                    converted_value_data_type=input_type,
                    converter_identifiers=[self.get_identifier()],
                )
            ]
        )
        response_msg = await self.send_variation_prompt_async(request)

        return ConverterResult(output_text=response_msg, output_type="text")

    @pyrit_json_retry
    async def send_variation_prompt_async(self, request: PromptRequestResponse):
        """
        Sends the request to the converter target and extracts the first variation.
        Parameters:
            request: the request to send to the converter target
        Return:
            the first element of the JSON array returned by the converter target
        Raises:
            InvalidJsonException: if the response is not valid JSON, or is not a
                non-empty JSON array
        """
        # NOTE(review): pyrit_json_retry presumably retries on InvalidJsonException - confirm in pyrit.exceptions.
        response = await self.converter_target.send_prompt_async(prompt_request=request)

        response_msg = remove_markdown_json(response.get_value())
        try:
            variations = json.loads(response_msg)
        except json.JSONDecodeError as e:
            raise InvalidJsonException(message=f"Invalid JSON response: {response_msg}") from e

        # Valid JSON is not enough: an object, a bare string/number/null or an empty
        # array carries no usable variation, so report it as an invalid response
        # instead of leaking an IndexError/TypeError or returning a single character.
        if not isinstance(variations, list) or not variations:
            raise InvalidJsonException(message=f"Invalid JSON response: {response_msg}")

        return variations[0]

    def input_supported(self, input_type: PromptDataType) -> bool:
        """Returns True only for the "text" input type."""
        return input_type == "text"

    def output_supported(self, output_type: PromptDataType) -> bool:
        """Returns True only for the "text" output type."""
        return output_type == "text"
