import json
import random
from contextlib import ExitStack
from dataclasses import dataclass
from pathlib import Path
import icu

from pipeline.common.datasets import (
    CountingStep,
    FilteringStep,
    Statistics,
    WeakStringSet,
)
from pipeline.common.downloads import location_exists, read_lines, write_lines
from pipeline.common.logging import get_logger
from pipeline.common.memory import log_memory

logger = get_logger(__name__)

random.seed(38947598475)


@dataclass
class HPLTDocument:
    """
    A structured type for an HPLT document entry in a jsonl file.
    https://hplt-project.org/datasets/v2.0
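
    Example of decoding one entry (`line` is a raw jsonl string; the values shown
    are illustrative only):

        doc = HPLTDocument(**json.loads(line))
        doc.lang[0]        # e.g. "rus_Cyrl", the most probable document language
        doc.doc_scores[0]  # e.g. 8.3, the overall WDS score
        doc.seg_langs[0]   # e.g. "rus_Cyrl", the detected language of the first line
        doc.lines[0]       # the first line of the document text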
    """

    def __init__(self, **json):
        self.lang = json["lang"]
        self.doc_scores = json["doc_scores"]
        self.seg_langs = json["seg_langs"]
        # The sentences in the text, which were separated by newlines.
        self.lines = json["text"].split("\n")

    # The list of detected document languages where the first language is most probable.
    # For example: [zho_Hans, zho_Hant, eng_Latn]
    lang: list[str]
    # The list of document scores from web-docs-scorer where the first score is the overall document score (WDS_score) followed by 8 subscores.
    # All the scores are from 0 to 10.
    # See https://github.com/pablop16n/web-docs-scorer/
    # For example, [8.3, 10, 10, 9.9, 10, 10, 10, 4, 0]
    doc_scores: list[float]
    # The detected language for each line (segment).
    # For example: [yue_Hant, zho_Hans, zho_Hans, zho_Hant, unk, ... ]
    seg_langs: list[str]
    # All of the text, split by newlines.
    lines: list[str]


class FilteringStatistics(Statistics):
    """
    Gather statistics about the filtering process.
    """

    def __init__(self, dataset_path: Path) -> None:
        super().__init__(dataset_path)
        self.shards = FilteringStep(
            "How many shards were sampled from. Each shard contains a subset of the "
            "total datasets available.",
        )
        self.visited_lines = FilteringStep(
            "How many lines were visited and kept from the HPLT documents.",
        )
        self.document_count = CountingStep(
            "How many documents were visited. This can help represent data diversity.",
        )
        self.duplicate_lines = CountingStep(
            "Of the collected lines, this counts how many were duplicates and discarded.",
        )
        self.final_lines = CountingStep(
            "How many lines were actually written.",
        )
        self.filtered_doc_locale = CountingStep(
            "How many lines were filtered based on document locale.",
        )
        self.filtered_line_locale = CountingStep(
            "How many lines were filtered based on line locales.",
        )
        self.filtered_doc_score = CountingStep(
            "How many lines were filtered based on document scores.",
        )
        self.filtered_too_long = CountingStep(
            "How many lines were filtered based on length.",
        )

    def count_shards_visited(self, *_args):
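        # Called by read_lines via on_enter_location each time a new shard is opened,
        # moving that shard from the "filtered" count into the "kept" count.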
        self.shards.filtered -= 1
        self.shards.kept += 1


def get_hplt_locale(lang_iso6931: str) -> str:
    """
    Converts a language code in ISO 639-1 format to the HPLT locale format.
    For example, ru -> rus_Cyrl
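
    A doctest-style sketch (exact output depends on the installed ICU data):

        >>> get_hplt_locale("en")
        'eng_Latn'
        >>> get_hplt_locale("ko")
        'kor_Hang'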
    """
    # ICU returns the Kore script by default, which is a mix of Hang and Hani.
    if lang_iso6931 == "ko":
        return "kor_Hang"
    locale = icu.Locale(lang_iso6931)
    # add default script
    locale = icu.Locale.addLikelySubtags(locale)
    hplt_locale = f"{locale.getISO3Language()}_{locale.getScript()}"
    return hplt_locale


def get_hplt_map_url(hplt_locale: str) -> str:
    return f"https://data.hplt-project.org/two/cleaned/{hplt_locale}_map.txt"


def language_has_hplt_support(language: str) -> bool:
    hplt_locale = get_hplt_locale(language)
    hplt_map = get_hplt_map_url(hplt_locale)
    return location_exists(hplt_map)


def load_shuffled_shard_urls(hplt_locale: str) -> list[str]:
    """
    Download the list of shard URLs and shuffle them deterministically (seeded by the map URL), e.g.
    https://data.hplt-project.org/two/cleaned/rus_Cyrl/1.jsonl.zst
    https://data.hplt-project.org/two/cleaned/rus_Cyrl/2.jsonl.zst
    ...
    https://data.hplt-project.org/two/cleaned/rus_Cyrl/10.jsonl.zst
    """

    url = get_hplt_map_url(hplt_locale)
    logger.info(f"Downloading shard list: {url}")

    with read_lines(url) as lines:
        shard_urls = []
        for line in lines:
            shard_urls.append(line.strip())
    random.Random(url).shuffle(shard_urls)

    logger.info(f"Available shards for {hplt_locale}:")
    for shard_url in shard_urls:
        logger.info(f" - {shard_url}")
    return shard_urls


class HpltDownloader:
    """
    Downloads and filters the HPLT dataset.
    https://hplt-project.org/datasets/v2.0

    Parameters:
     - language: The BCP 47 language code to filter the documents.
     - hplt_min_doc_score: The minimum score a document must have to be included in the final dataset.
     - max_characters: The maximum number of characters to accumulate in a merged segment before it is
                       written out (when merging is enabled). Lines longer than this are also filtered out.
     - max_lines: The maximum number of lines to include in the final dataset.
     - file_destination: The destination path where the final dataset will be written.
     - merge_lines: Whether to accumulate lines of the same document into one segment until max_characters is reached.
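
    Example of illustrative usage (the values and path below are made up for this sketch):

        downloader = HpltDownloader(
            language="ru",
            hplt_min_doc_score=7.0,
            max_characters=600,
            max_lines=100_000_000,
            file_destination=Path("artifacts/hplt-mono.ru.txt.zst"),
            merge_lines=True,
        )
        downloader.download()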
    """

    def __init__(
        self,
        language: str,
        hplt_min_doc_score: float,
        max_characters: int,
        max_lines: int,
        file_destination: Path,
        merge_lines: bool,
    ) -> None:
        self.merge_lines = merge_lines
        self.max_lines = max_lines
        self.max_characters = max_characters
        self.hplt_min_doc_score = hplt_min_doc_score
        self.hplt_locale = get_hplt_locale(language)
        self.accumulated_text = ""
        self.cumulative_char_count = 0
        self.visited_lines = 0
        self.file_destination = file_destination
        self.stats = FilteringStatistics(file_destination)
        self.strings_seen = WeakStringSet()
        self.stack = ExitStack()
        self.outfile = self.stack.enter_context(write_lines(file_destination))

    def close(self):
        self.stack.close()

    def download(self):
        try:
            self._run_download()
        finally:
            self.close()

    def _run_download(self):
        logger.info(f"Using HPLT locale {self.hplt_locale}")
        shuffled_shard_urls = load_shuffled_shard_urls(self.hplt_locale)
        self.stats.shards.filtered = len(shuffled_shard_urls)

        # The shard URLs are shuffled, and then streamed into the read_lines iterator.
        # This iterator works across multiple shards: the first shard is opened and its
        # documents are read in order, then the iterator continues with the following
        # shards until enough fluent sentences are collected. At that point the
        # remaining shards will not be visited.
        document_stream = self.stack.enter_context(
            read_lines(shuffled_shard_urls, on_enter_location=self.stats.count_shards_visited)
        )

        for document_json in document_stream:
            self.stats.document_count.value += 1
            document = HPLTDocument(**json.loads(document_json))
            overall_doc_score = document.doc_scores[0]
            doc_lang = document.lang[0]

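            # Flush any text accumulated from the previous document so that merged
            # segments never span document boundaries.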
            self._maybe_write_accumulated_text()

            # HPLT 2.0 uses document level scores
            if overall_doc_score < self.hplt_min_doc_score:
                self.stats.filtered_doc_score.value += 1
                continue

            # We want only documents written primarily in the target language
            if doc_lang != self.hplt_locale:
                self.stats.filtered_doc_locale.value += 1
                continue

            # Visit the lines in the document.
            for line_locale, line in zip(document.seg_langs, document.lines):
                self.visited_lines += 1
                self._process_line(line_locale, line)
                if self.visited_lines % 5_000_000 == 0:
                    logger.info(f"Visited {self.visited_lines:,} lines")
                    logger.info(f"Kept {self.stats.visited_lines.kept:,}.")
                    logger.info(
                        f"Wrote {self.stats.final_lines.value:,} out of {self.max_lines:,}."
                    )
                    log_memory()

                if self.stats.final_lines.value == self.max_lines:
                    break

            if self.stats.final_lines.value == self.max_lines:
                break

            self._maybe_write_accumulated_text()

        self.stats.visited_lines.filtered = self.visited_lines - self.stats.visited_lines.kept
        logger.info(f"Wrote {self.stats.final_lines.value:,} lines to: {self.file_destination}")
        stat_path = self.stats.save_json()
        logger.info(f"Saved filtering stats to: {stat_path}")

    def _process_line(self, line_locale: str, line: str):
        # Line locale does not match expected locale, filter
        if line_locale != self.hplt_locale:
            self.stats.filtered_line_locale.value += 1
            self._maybe_write_accumulated_text()
            return

        char_count = len(line)
        # Filter long segments
        if char_count > self.max_characters:
            self.stats.filtered_too_long.value += 1
            self._maybe_write_accumulated_text()
            return

        # Just write the current line if merging is disabled
        if not self.merge_lines:
            self.accumulated_text = line
            self.stats.visited_lines.kept += 1
            self._maybe_write_accumulated_text()
            return

        # Text accumulation mode starts here

        self.stats.visited_lines.kept += 1

        # Determine if this sentence should be added to the previous one or
        # written out as a new line.
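        # Illustrative walk-through, assuming max_characters=100 and an empty buffer:
        #   40-char line: 0 + 40 + 1 <= 100 -> keep accumulating, cumulative becomes 40
        #   50-char line: 40 + 50 + 1 <= 100 -> join with a space, cumulative becomes 91
        #   30-char line: 91 + 30 + 1 > 100  -> flush the buffer first, then restart at 30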
        if self.cumulative_char_count + char_count + 1 > self.max_characters:
            # This line would be too long, write it out.
            self._maybe_write_accumulated_text()

        self.cumulative_char_count += char_count
        # Collect this line to write.
        if self.accumulated_text:
            self.accumulated_text = f"{self.accumulated_text} {line}"
            # count the whitespace
            self.cumulative_char_count += 1
        else:
            self.accumulated_text = line

    def _maybe_write_accumulated_text(self):
        """
        Since the download loop builds up paragraphs of text, we only want to write
        out a line once enough text has been accumulated. The paragraph is written
        out when either the text gets too long, or the next line is discarded.
        """

        self.cumulative_char_count = 0
        if self.accumulated_text:
            if self.accumulated_text in self.strings_seen:
                self.stats.duplicate_lines.value += 1
            else:
                self.outfile.write(self.accumulated_text + "\n")
                self.stats.final_lines.value += 1
                self.strings_seen.add(self.accumulated_text)
            self.accumulated_text = ""
