import json
from typing import TYPE_CHECKING, List, Literal, Union

from datasets import Dataset, concatenate_datasets
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, GeneratorStep, KeepColumns, Step, StepInput
from distilabel.steps.tasks import TextGeneration
from typing_extensions import override

# Jinja-syntax prompt for the "rejected -> chosen" direction: the rendered
# conversation ends with a poorly rated message and the model is asked to write
# a better replacement. It is rendered from the ``conversation`` column; each
# message is read via its ``role`` and ``content`` keys.
CHOSEN_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of poor quality. Your task is to write a high quality replacement for it.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
High quality response:
""".rstrip()

# System prompt paired with CHOSEN_TEMPLATE.
CHOSEN_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a high quality response when other assistants created a poor quality response."

# Jinja-syntax prompt for the "chosen -> rejected" direction: the rendered
# conversation ends with a highly rated message and the model is asked to write
# a worse replacement. Same ``conversation`` input shape as CHOSEN_TEMPLATE.
REJECT_TEMPLATE = """
You are provided with a conversation between a human and an AI assistant.
The final message is of high quality. Your task is to write a poor quality replacement for it.
{% for message in conversation %}
{{ message["role"] }}: {{ message["content"] }}
{% endfor %}
Poor quality response:
""".rstrip()

# System prompt paired with REJECT_TEMPLATE.
REJECT_SYSTEM_PROMPT = "You are a helpful AI assistant. Your task is to generate a poor quality response when other assistants created a high quality response."


class FilterConversationRatings(Step):
    """Truncate each conversation at its first message carrying a target rating.

    The target rating comes from ``target_column``: ``"chosen"`` maps to ``1``
    and ``"rejected"`` maps to ``-1``. For each incoming row, the
    ``"conversation"`` list is cut right after the first message whose
    ``"rating"`` is an ``int`` equal to the target rating. That message is kept
    as the last one. Rows without such a message are dropped. Output rows carry
    only the ``"conversation"`` column.
    """

    # Which rated side to cut conversations at ("chosen" -> 1, "rejected" -> -1).
    target_column: Union[Literal["chosen"], Literal["rejected"]]
    # Number of input rows grouped into each yielded batch.
    batch_size: int = 5

    @override
    def process(self, dataset: StepInput) -> "StepOutput":
        """Yield truncated conversations, one list per ``batch_size`` input rows.

        Args:
            dataset: incoming rows; each must have a ``"conversation"`` list of
                message dicts. Messages without an ``int`` ``"rating"`` (or with
                no ``"rating"`` key at all) are skipped during the search.

        Yields:
            Lists of ``{"conversation": truncated_messages}`` dicts. A list may
            be empty if no row in that slice contained the target rating.
        """
        column_rating_map = {
            "chosen": 1,
            "rejected": -1,
        }
        target_rating = column_rating_map[self.target_column]

        for batch_start in range(0, len(dataset), self.batch_size):
            batch = dataset[batch_start : batch_start + self.batch_size]
            filtered_batch = []
            # Visit each row exactly once. A previous nested loop over the same
            # batch appended every matching row len(batch) times.
            for row in batch:
                messages = row["conversation"]
                truncated = None
                for idx, message in enumerate(messages, 1):
                    rating = message.get("rating")
                    if isinstance(rating, int) and rating == target_rating:
                        # Keep everything up to and including the rated message.
                        truncated = messages[:idx]
                        break
                if truncated:
                    filtered_batch.append({"conversation": truncated})
            yield filtered_batch

    @property
    def outputs(self) -> "StepColumns":
        """Columns produced by this step."""
        return ["conversation"]


class AppendToConversationStep(Step):
    """Replace the last message of each conversation with a generated reply.

    For every row with a non-empty ``generation`` and ``conversation``:

    * ``generated_conversation`` is set to the conversation without its final
      message, followed by ``{"role": "assistant", "content": generation}``.
    * ``conversation`` is rewritten so each message keeps only its ``role`` and
      ``content`` keys.

    Rows with an empty/falsy ``generation`` or ``conversation`` are dropped, so
    every yielded row has both output columns. The pipelines in this file rename
    both columns via ``output_mappings`` and then keep only the renamed pair.
    """

    @property
    def inputs(self) -> "StepColumns":
        """Columns this step reads."""
        return ["generation", "conversation"]

    @property
    def outputs(self) -> "StepColumns":
        """Columns this step writes."""
        return ["generated_conversation", "conversation"]

    def process(self, inputs: StepInput) -> "StepOutput":
        """Yield one batch containing only the rows that could be completed.

        Args:
            inputs: rows with ``"generation"`` (presumably the LLM text, possibly
                empty/None when generation failed upstream — confirm) and
                ``"conversation"`` (a list of message dicts with ``role`` and
                ``content`` keys).

        Yields:
            The completed rows, mutated in place with the two output columns.
        """
        completed = []
        for row in inputs:
            generation = row["generation"]
            messages = row["conversation"]
            if not generation or not messages:
                # Previously such rows were passed through without a
                # "generated_conversation" key, which breaks downstream column
                # selection; drop them instead.
                continue
            row["generated_conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in messages[:-1]
            ] + [{"role": "assistant", "content": generation}]
            row["conversation"] = [
                {"role": message["role"], "content": message["content"]}
                for message in messages
            ]
            completed.append(row)
        yield completed


# Pipeline 1 ("rejected -> chosen"): start from conversations cut at a message
# rated -1, ask the LLM for a better replacement of that final message, and emit
# the generated conversation as "chosen" and the original one as "rejected".
with Pipeline(
    name="conversation_rejection",
    description="Generate a chosen response to a rejected conversation.",
) as rejection_pipeline:

    # Keep each conversation up to and including its first message rated -1.
    rejected_dataset = FilterConversationRatings(target_column="rejected")

    # Render CHOSEN_TEMPLATE from the "conversation" column and ask the model
    # for a high quality replacement of the final message.
    chosen_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=CHOSEN_SYSTEM_PROMPT,
        template=CHOSEN_TEMPLATE,
        columns=["conversation"],
    )

    # Swap the generated reply in for the final message. The mappings rename the
    # generated conversation to "chosen" and the original one to "rejected".
    append_chosen = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "chosen",
            "conversation": "rejected",
        },
    )

    # Drop everything except the preference pair.
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )

    # Connect the steps linearly; the chain order is the data-flow order.
    rejected_dataset >> chosen_text_gen >> append_chosen >> keep_columns

# Pipeline 2 ("chosen -> rejected"): start from conversations cut at a message
# rated 1, ask the LLM for a worse replacement of that final message, and emit
# the generated conversation as "rejected" and the original one as "chosen".
with Pipeline(
    name="conversation_chosen",
    description="Generate a rejected response to a chosen conversation.",
) as chosen_pipeline:

    # Keep each conversation up to and including its first message rated 1.
    chosen_dataset = FilterConversationRatings(target_column="chosen")

    # Render REJECT_TEMPLATE from the "conversation" column and ask the model
    # for a poor quality replacement of the final message.
    rejected_text_gen = TextGeneration(
        llm=InferenceEndpointsLLM(
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
        ),
        system_prompt=REJECT_SYSTEM_PROMPT,
        template=REJECT_TEMPLATE,
        columns=["conversation"],
    )
    # Swap the generated reply in for the final message. The mappings rename the
    # generated conversation to "rejected" and the original one to "chosen".
    append_rejected = AppendToConversationStep(
        output_mappings={
            "generated_conversation": "rejected",
            "conversation": "chosen",
        },
    )
    # Drop everything except the preference pair.
    keep_columns = KeepColumns(
        columns=["chosen", "rejected"],
    )
    # Connect the steps linearly; the chain order is the data-flow order.
    chosen_dataset >> rejected_text_gen >> append_rejected >> keep_columns

if __name__ == "__main__":
    # Run both pipelines over the same rated conversations and stack their
    # (chosen, rejected) outputs into a single dataset.

    dataset_path = "example_data.json"
    # Context manager closes the handle; JSON text is UTF-8 by specification.
    with open(dataset_path, encoding="utf-8") as f:
        data = json.load(f)

    dataset = Dataset.from_list(data)
    # Distinct names so the module-level step objects `rejected_dataset` and
    # `chosen_dataset` are not rebound to the pipeline results.
    rejection_distiset = rejection_pipeline.run(dataset=dataset, use_cache=False)
    chosen_distiset = chosen_pipeline.run(dataset=dataset, use_cache=False)

    # NOTE: the combined dataset is only held in memory here; nothing in this
    # script saves or pushes it.
    dataset = concatenate_datasets(
        dsets=[
            rejection_distiset["default"]["train"],
            chosen_distiset["default"]["train"],
        ]
    )
