"""
DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
Dheeru Dua, Yizhong Wang, Pradeep Dasigi, Gabriel Stanovsky, Sameer Singh, Matt Gardner
https://arxiv.org/abs/1903.00161
"""

import gzip
import json
import random
import re
import string
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from scipy.optimize import linear_sum_assignment

from . import common
from .common import ANSWER_PATTERN, HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult

"""
From here through _normalize_answer was originally copied from:
https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
Then cleaned up and modified a bit.

The rest was originally copied from
https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py
"""


def _remove_articles(text: str) -> str:
    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
    return re.sub(regex, " ", text)


def _white_space_fix(text: str) -> str:
    return " ".join(text.split())


# ASCII punctuation characters removed by `_remove_punc`.
EXCLUDE = set(string.punctuation)


def _remove_punc(text: str) -> str:
    """Strip punctuation from *text*, unless the whole string parses as a number.

    Numeric strings (e.g. "-2" or "3.5") are returned untouched so signs and decimal points
    survive for the later numeric canonicalization.
    """
    if _is_number(text):
        return text
    return "".join(filter(lambda char: char not in EXCLUDE, text))


def _lower(text: str) -> str:
    return text.lower()


def _tokenize(text: str) -> List[str]:
    return re.split(" |-", text)


def _normalize_answer(text: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace.

    The text is split on spaces and hyphens. Each piece is lowercased, stripped of punctuation
    (unless numeric), canonicalized through ``float`` if numeric, stripped of articles and
    whitespace-collapsed. Pieces that end up blank are dropped; the rest are joined by spaces.
    """
    kept: List[str] = []
    for token in _tokenize(text):
        cleaned = _lower(token)
        cleaned = _remove_punc(cleaned)
        cleaned = _normalize_number(cleaned)
        cleaned = _remove_articles(cleaned)
        cleaned = _white_space_fix(cleaned)
        if cleaned.strip():
            kept.append(cleaned)
    return " ".join(kept).strip()


def _is_number(text: str) -> bool:
    try:
        float(text)
        return True
    except ValueError:
        return False


def _normalize_number(text: str) -> str:
    """Canonicalize numeric text through ``float`` (e.g. "5" -> "5.0"); pass other text through."""
    return str(float(text)) if _is_number(text) else text


def _answer_to_bags(
    answer: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[List[str], List[Set[str]]]:
    """Normalize an answer (one string, or a list/tuple of span strings).

    Returns:
        ``(normalized_spans, token_bags)`` where ``token_bags[i]`` is the set of
        whitespace-separated tokens of ``normalized_spans[i]``.
    """
    spans = answer if isinstance(answer, (list, tuple)) else [answer]
    normalized_spans: List[str] = [_normalize_answer(span) for span in spans]
    token_bags = [set(span.split()) for span in normalized_spans]
    return normalized_spans, token_bags


def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> np.ndarray:
    """
    Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
    between them and gets maximum metric values over all the answers.

    A (gold x predicted) matrix of F1 scores is built. A pair scores 0 unless the predicted
    bag shares at least one number with the gold bag, when the gold bag contains any numbers.
    The assignment that maximizes total F1 is then found.

    Returns:
        A float array of length ``max(len(gold), len(predicted))``. Entry ``i`` is the F1 of
        gold bag ``i`` with the prediction assigned to it. Entries with no assignment stay 0,
        so the mean of this array penalizes both unmatched gold bags and surplus predictions.
    """
    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            if _match_numbers_if_present(gold_item, pred_item):
                scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
    # linear_sum_assignment minimizes total cost, so negate the scores to maximize total F1.
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    # Each gold row appears at most once in the assignment, so a direct store is enough.
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = scores[row, column]
    return max_scores


def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
    intersection = len(gold_bag.intersection(predicted_bag))
    if not predicted_bag:
        precision = 1.0
    else:
        precision = intersection / float(len(predicted_bag))
    if not gold_bag:
        recall = 1.0
    else:
        recall = intersection / float(len(gold_bag))
    f1 = (
        (2 * precision * recall) / (precision + recall)
        if not (precision == 0.0 and recall == 0.0)
        else 0.0
    ) * 100
    return f1


def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
    """Return True unless the gold bag has numbers and the prediction shares none of them.

    Numbers are detected with `_is_number` and compared as (already normalized) strings.
    """
    gold_numbers = {word for word in gold_bag if _is_number(word)}
    predicted_numbers = {word for word in predicted_bag if _is_number(word)}
    if not gold_numbers:
        return True
    return bool(gold_numbers & predicted_numbers)


def get_drop_metrics(
    predicted: Union[str, List[str], Tuple[str, ...]], gold: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[float, float]:
    """
    Compute DROP exact match and F1 for one prediction against one gold answer.

    Both arguments may be a single string or a list/tuple of span strings. When reading gold
    answers from the released data file, convert them first with :func:`answer_json_to_strings`.

    Returns:
        ``(exact_match, f1)``. ``exact_match`` is 1.0 when the normalized spans match as sets
        and have equal counts, otherwise 0.0. ``f1`` is the mean aligned bag F1 (0-100),
        rounded to 2 decimals.
    """
    pred_spans, pred_bags = _answer_to_bags(predicted)
    gold_spans, gold_bags = _answer_to_bags(gold)

    spans_agree = set(pred_spans) == set(gold_spans) and len(pred_spans) == len(gold_spans)
    exact_match = 1.0 if spans_agree else 0.0

    f1 = round(np.mean(_align_bags(pred_bags, gold_bags)), 2)
    return exact_match, f1


def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
    """
    Convert a DROP answer JSON blob into ``(answer_strings, answer_type)``.

    The fields are checked in priority order: a truthy ``number`` gives ``"number"``; a
    non-empty ``spans`` gives ``"span"`` (one span) or ``"spans"``; a present ``date`` is
    rendered as ``"day month year"`` (stripped) and gives ``"date"``.

    Raises:
        ValueError: If none of the fields apply.
    """
    if answer.get("number"):
        return (str(answer["number"]),), "number"

    spans = answer.get("spans")
    if spans:
        span_kind = "span" if len(spans) == 1 else "spans"
        return tuple(spans), span_kind

    if "date" in answer:
        date = answer["date"]
        rendered = "{0} {1} {2}".format(date["day"], date["month"], date["year"]).strip()
        return (rendered,), "date"

    raise ValueError(
        f"Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}"
    )


def answer_json_to_string(answer_json):
    """Return the ``(answer_strings, answer_type)`` pair from `answer_json_to_strings` as JSON text."""
    strings_and_type = answer_json_to_strings(answer_json)
    return json.dumps(strings_and_type)


def normalize(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace.

    This works on the whole string at once. Hyphens are deleted like any other punctuation
    rather than treated as separators, and numbers are not canonicalized.
    """
    without_punct = s.lower().translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", without_punct)
    return " ".join(without_articles.split())


def fuzzy_match(s1: str, s2: str) -> bool:
    """Return True when one normalized string contains the other.

    An empty normalized string only matches another empty one.
    """
    left, right = normalize(s1), normalize(s2)
    if not left or not right:
        return left == right
    return left in right or right in left


def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]:
    """
    Score one predicted answer against several acceptable reference answers.

    Each non-blank reference is scored with `get_drop_metrics`. The best exact match and the
    best F1 are returned; they may come from different references.

    Args:
        sample: The answer extracted from the model response.
        reference: Acceptable gold answers. Empty or whitespace-only entries are skipped.

    Returns:
        ``(exact_match, f1)``. If no reference is non-blank, ``(0.0, 0.0)`` is returned
        instead of letting ``max()`` raise on an empty sequence.
    """
    em_scores = []
    f1_scores = []
    for answer in reference:
        if not answer.strip():
            continue
        em, f1 = get_drop_metrics(sample, answer)
        em_scores.append(em)
        f1_scores.append(f1)
    return max(em_scores, default=0.0), max(f1_scores, default=0.0)


class DropEval(Eval):
    """
    DROP reading-comprehension eval with few-shot prompting.

    Each dev-set example is preceded by ``train_samples_per_prompt`` worked examples drawn
    from the training split. The model's answer is taken from its ``ANSWER_PATTERN`` match,
    or from the whole response when there is no match. It is graded against every
    ``|``-separated reference in ``ref_text``:
    ``score`` is a normalized substring match (`fuzzy_match`), and DROP exact match and F1
    (`drop_metric`) are reported as extra metrics.
    """

    # Prompt fragments reproduced byte-for-byte from the original inline prompt construction,
    # including the trailing newline + 20 spaces after the instructions, so that prompts (and
    # therefore scores) stay comparable with earlier runs.
    _PROMPT_HEADER = (
        "You will be asked to read a passage and answer a question. "
        "Some examples of passages and Q&A are provided below."
    )
    _ANSWER_INSTRUCTIONS = (
        '\n\nThink step by step, then write a line of the form "Answer: $ANSWER" '
        "at the end of your response.\n" + " " * 20
    )

    def __init__(self, num_examples: int | None = None, train_samples_per_prompt: int = 3):
        self.seed = 42
        self._num_examples = num_examples
        self._train_samples_per_prompt = train_samples_per_prompt
        self.train_jsonl = (
            "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz"
        )
        self.test_jsonl = (
            "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
        )
        self.train_samples = self._load_jsonl_gz(self.train_jsonl)
        self.test_samples = self._load_jsonl_gz(self.test_jsonl)
        # A falsy num_examples (None or 0) keeps the whole dev set.
        if self._num_examples:
            self.test_samples = random.Random(self.seed).sample(
                self.test_samples, self._num_examples
            )

    @staticmethod
    def _load_jsonl_gz(url: str) -> list[dict]:
        """Fetch a gzip-compressed JSON-lines file and parse one JSON object per line."""
        with gzip.GzipFile(fileobj=common.url_to_fileobj(url, binary=True), mode="rb") as f:
            return [json.loads(line) for line in f]

    def _build_prompt(self, stuffing: list[dict], example: dict) -> str:
        """Assemble the prompt: header, worked examples, then the unanswered test example."""
        parts = [self._PROMPT_HEADER, "\n\n# Examples"]
        for shot in stuffing:
            # Worked example: its context followed directly by its reference completion.
            parts.append(f"\n---\n{shot['context']} ")
            parts.append(shot["completion"] + "\n")
        parts.append("\n# Your Task\n")
        parts.append(f"\n---\n{example['context']} ")
        parts.append(self._ANSWER_INSTRUCTIONS)
        return "".join(parts)

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        # Draw every test example's few-shot examples up front, in test-set order, from one
        # seeded RNG. Drawing inside `fn` tied the draws to the order in which
        # `map_with_progress` invokes `fn`, which is nondeterministic if it runs calls
        # concurrently (NOTE(review): confirm its threading model in `common`). For a
        # sequential run this produces exactly the same draws as before.
        rng = random.Random(self.seed)
        jobs = [
            (rng.sample(self.train_samples, self._train_samples_per_prompt), example)
            for example in self.test_samples
        ]

        def fn(job: tuple[list[dict], dict[str, str]]) -> SingleEvalResult:
            stuffing, example = job
            prompt = self._build_prompt(stuffing, example)
            prompt_messages = [sampler._pack_message(content=prompt, role="user")]
            response_text = sampler(prompt_messages)
            match = re.search(ANSWER_PATTERN, response_text)
            extracted_answer = match.group(1) if match else response_text

            correct_answers = example["ref_text"].split("|")
            em_score, f1_score = drop_metric(extracted_answer, correct_answers)
            matches = [
                fuzzy_match(extracted_answer, correct_answer)
                for correct_answer in correct_answers
            ]
            # One copy of the extracted answer per reference it fuzzily matched, for the report.
            extracted_answers = [extracted_answer for matched in matches if matched]
            score = any(matches)

            # Record the full model response (not just the extracted answer) as the
            # assistant turn, so the report and convo show what the model actually said.
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=response_text, role="assistant"),
                score=score,
                correct_answer=correct_answers,
                extracted_answer=extracted_answers,
            )
            convo = prompt_messages + [dict(content=response_text, role="assistant")]
            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics={"em_score": em_score, "f1_score": f1_score},
            )

        results = common.map_with_progress(fn, jobs)
        return common.aggregate_results(results)
