#!/usr/bin/env python3
"""
Downloads a parallel dataset and runs augmentation if needed

Example:
    python pipeline/data/parallel_importer.py \
        --dataset=sacrebleu_aug-mix_wmt19 \
        --output_prefix=$(pwd)/test_data/augtest \
        --src=ru \
        --trg=en
"""

import argparse
import os
import random
import re
import sys
from pathlib import Path
from typing import Dict, Iterable, List

from opustrainer.modifiers.noise import NoiseModifier
from opustrainer.modifiers.placeholders import PlaceholderTagModifier
from opustrainer.modifiers.surface import TitleCaseModifier, UpperCaseModifier
from opustrainer.modifiers.typos import TypoModifier
from opustrainer.types import Modifier

from pipeline.common.downloads import compress_file, decompress_file
from pipeline.common.logging import get_logger
from pipeline.data.cjk import handle_chinese_parallel, ChineseType


from pipeline.data.parallel_downloaders import download, Downloader

# Seed the global `random` generator with a fixed value so that the random choices
# made through it (e.g. the typo selection in get_typos_probs) are repeatable between runs.
random.seed(1111)

# Module-level logger used by the import and augmentation steps below.
logger = get_logger(__file__)


class CompositeModifier:
    """
    Chains several modifiers: the output of each one is the input of the next
    """

    def __init__(self, modifiers: List[Modifier]):
        # Modifiers are applied in the order they are listed here
        self._modifiers = modifiers

    def __call__(self, batch: List[str]) -> Iterable[str]:
        current = batch
        for apply_modifier in self._modifiers:
            # Materialize every intermediate result before handing it to the next modifier
            current = [*apply_modifier(current)]

        return current


# Probability given to the typo, casing and noise modifiers inside the "aug-mix" / "aug-mix-cjk" composites
MIX_PROB = 0.05  # 5% will be augmented in the mix
PROB_1 = 1.0  # 100% chance
PROB_0 = 0.0  # 0% chance
# Inline noise uses probabilities lower than 1,
# because probability 1 adds way too much noise to a corpus
NOISE_PROB = 0.05  # used by the standalone "aug-inline-noise" modifier
NOISE_MIX_PROB = 0.01  # used for the inline noise inside the mixes


def get_typos_probs() -> Dict[str, float]:
    """
    Pick 4 random typo types and build the probabilities for all the typo types.

    Returns a dictionary with one entry per key of TypoModifier.modifiers:
    PROB_1 for the 4 selected typo types and PROB_0 for all the others.
    """
    all_typos = list(TypoModifier.modifiers.keys())
    # the selection goes through the global `random` generator
    enabled = set(random.sample(all_typos, k=4))

    probs = {}
    for typo in all_typos:
        probs[typo] = PROB_1 if typo in enabled else PROB_0
    return probs


def _build_mix():
    """
    Create the "aug-mix" modifier: typos, title case, upper case, noise and inline noise,
    each applied with a low probability.
    """
    return CompositeModifier(
        [
            TypoModifier(MIX_PROB, **get_typos_probs()),
            TitleCaseModifier(MIX_PROB),
            UpperCaseModifier(MIX_PROB),
            NoiseModifier(MIX_PROB),
            PlaceholderTagModifier(NOISE_MIX_PROB, augment=1),
        ]
    )


def _build_mix_cjk():
    """
    Create the "aug-mix-cjk" modifier: only noise and inline noise, each applied with a low probability.
    """
    return CompositeModifier(
        [
            NoiseModifier(MIX_PROB),
            PlaceholderTagModifier(NOISE_MIX_PROB, augment=1),
        ]
    )


# Maps an augmentation name (as used in dataset identifiers) to a factory that creates a fresh modifier.
# See documentation for the modifiers in https://github.com/mozilla/translations/blob/main/docs/training/opus-trainer.md#supported-modifiers
modifier_map = {
    "aug-typos": lambda: TypoModifier(PROB_1, **get_typos_probs()),
    "aug-title": lambda: TitleCaseModifier(PROB_1),
    "aug-upper": lambda: UpperCaseModifier(PROB_1),
    "aug-noise": lambda: NoiseModifier(PROB_1),
    "aug-inline-noise": lambda: PlaceholderTagModifier(NOISE_PROB, augment=1),
    "aug-mix": _build_mix,
    "aug-mix-cjk": _build_mix_cjk,
}


def add_alignments(corpus: List[str]) -> List[str]:
    """
    Append a word alignments column to every line of a TSV corpus.

    Each input line is "<source>\t<target>". Each returned line is the input line followed by
    a tab and the space separated "<source position>-<target position>" alignment pairs.
    """
    from simalign import SentenceAligner  # type: ignore

    # An unsupervised aligner is used here, because statistical tools like fast_align
    # require a large corpus to train on.
    # This is slow without a GPU and is meant to operate only on small evaluation datasets.

    # BERT with subwords and itermax has a higher recall and matches more words than other methods.
    # See more details in the paper: https://arxiv.org/pdf/2004.08728.pdf
    # and in the source code: https://github.com/cisnlp/simalign/blob/master/simalign/simalign.py
    # This will download a 700Mb BERT model from Hugging Face and cache it
    aligner = SentenceAligner(model="bert", token_type="bpe", matching_methods="i")

    corpus_tsv = []
    for sentence_pair in corpus:
        source, target = sentence_pair.split("\t")
        word_pairs = aligner.get_word_aligns(source, target)["itermax"]
        aln = " ".join(f"{src_idx}-{trg_idx}" for src_idx, trg_idx in word_pairs)
        corpus_tsv.append(f"{sentence_pair}\t{aln}")

    return corpus_tsv


# This is intended only for small evaluation datasets
def augment(output_prefix: str, aug_modifer: str, src: str, trg: str):
    """
    Apply an OpusTrainer modifier to the corpus stored on disk.

    The corpus is read from <output_prefix>.<src>.zst and <output_prefix>.<trg>.zst,
    modified line by line and written back compressed.

    Raises ValueError when aug_modifer is not a key of modifier_map.
    """
    if aug_modifer not in modifier_map:
        raise ValueError(f"Invalid modifier {aug_modifer}. Allowed values: {modifier_map.keys()}")

    # the compressed files have the same paths with the extra .zst extension
    src_path, trg_path = (f"{output_prefix}.{lang}" for lang in (src, trg))
    corpus = read_corpus_tsv(f"{src_path}.zst", f"{trg_path}.zst", src_path, trg_path)

    uses_inline_noise = aug_modifer in ("aug-mix", "aug-inline-noise", "aug-mix-cjk")
    if uses_inline_noise:
        # inline noise requires alignments,
        # the tags modifier removes them again after processing
        corpus = add_alignments(corpus)

    create_modifier = modifier_map[aug_modifer]
    modified = []
    for line in corpus:
        # a new modifier is created for every line to apply randomization (for typos)
        modified.extend(create_modifier()([line]))

    write_modified(modified, src_path, trg_path)


def read_corpus_tsv(
    compressed_src: str, compressed_trg: str, uncompressed_src: str, uncompressed_trg: str
) -> List[str]:
    """
    Decompress corpus and read to TSV

    The two compressed files are decompressed (without keeping the compressed originals)
    and read as UTF-8 text. Returns one "<source>\t<target>" line per sentence pair,
    without the trailing line break.
    """
    # Remove uncompressed files that already exist, they are recreated by the decompression below
    for existing_path in (uncompressed_src, uncompressed_trg):
        if os.path.isfile(existing_path):
            os.remove(existing_path)

    # Decompress the original corpus.
    decompress_file(compressed_src, keep_original=False)
    decompress_file(compressed_trg, keep_original=False)

    # Since this is only used on small evaluation sets, it's fine to load the entire dataset
    # and augmentation into memory rather than streaming it.
    # The encoding is explicit: the default one depends on the locale of the machine,
    # which can break the decoding of a multilingual corpus.
    with open(uncompressed_src, encoding="utf-8") as f:
        corpus_src = [line.rstrip("\n") for line in f]
    with open(uncompressed_trg, encoding="utf-8") as f:
        corpus_trg = [line.rstrip("\n") for line in f]

    # NOTE: zip stops at the shorter side, extra lines of a longer file are ignored
    corpus_tsv = [f"{src_sent}\t{trg_sent}" for src_sent, trg_sent in zip(corpus_src, corpus_trg)]
    return corpus_tsv


def write_modified(modified: List[str], uncompressed_src: str, uncompressed_trg: str):
    """
    Split the modified TSV corpus, write back and compress

    The first column of every line goes to uncompressed_src and the second one to
    uncompressed_trg, one sentence per line, encoded as UTF-8. Any further columns are dropped.
    An empty corpus produces empty files. Both files are then compressed
    without keeping the uncompressed originals.

    Raises IndexError when a line has no tab separated second column.
    """
    src_lines = []
    trg_lines = []
    for line in modified:
        columns = line.split("\t")
        src_lines.append(f"{columns[0]}\n")
        trg_lines.append(f"{columns[1]}\n")

    # The encoding is explicit: the default one depends on the locale of the machine,
    # which can break the encoding of a multilingual corpus.
    with open(uncompressed_src, "w", encoding="utf-8") as f:
        f.writelines(src_lines)
    with open(uncompressed_trg, "w", encoding="utf-8") as f:
        f.writelines(trg_lines)

    # compress corpus back
    compress_file(uncompressed_src, keep_original=False)
    compress_file(uncompressed_trg, keep_original=False)


def run_import(
    type: str,
    dataset: str,
    output_prefix: str,
    src: str,
    trg: str,
):
    """
    Download a parallel dataset and augment it when the dataset identifier requests it.

    `type` is accepted for the command line interface, it is not used by this function.
    Raises ValueError when the dataset identifier cannot be parsed.
    """
    # A dataset identifier consists of an importer, an optional augmentation type and a dataset name
    # Examples:
    # opus_wikimedia/v20230407
    # opus_ELRC_2922/v1
    # mtdata_EU-eac_forms-1-eng-lit
    # flores_aug-title_devtest
    # sacrebleu_aug-mix_wmt19
    parsed = re.match(r"^([a-z]*)_(aug[a-z\-]*)?_?(.+)$", dataset)
    if parsed is None:
        raise ValueError(
            f"Invalid dataset name: {dataset}. "
            f"Use the following format: <importer>_<name> or <importer>_<augmentation>_<name>."
        )

    # aug_modifer is None when the identifier has no augmentation part
    importer, aug_modifer, name = parsed.groups()

    download(Downloader(importer), src, trg, name, Path(output_prefix))

    # TODO: convert everything to Chinese simplified for now when Chinese is the source language
    # TODO: https://github.com/mozilla/firefox-translations-training/issues/896
    if src == "zh" or trg == "zh":
        handle_chinese_parallel(output_prefix, src=src, trg=trg, variant=ChineseType.simplified)

    if not aug_modifer:
        return

    logger.info("Running augmentation")
    augment(output_prefix, aug_modifer, src=src, trg=trg)


def main() -> None:
    """
    Command line entry point: parse the arguments and run the dataset import.

    --dataset, --output_prefix, --src and --trg are mandatory, argparse reports a usage error
    when one of them is missing. --type is optional.
    """
    logger.info(f"Running with arguments: {sys.argv}")
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Preserves whitespace in the help text.
        formatter_class=argparse.RawTextHelpFormatter,
    )

    parser.add_argument("--type", metavar="TYPE", type=str, help="Dataset type: mono or corpus")
    # The arguments below are marked as required so a missing one is reported by argparse,
    # rather than surfacing later as a TypeError on a None value.
    parser.add_argument(
        "--src",
        metavar="SRC",
        type=str,
        required=True,
        help="Source language",
    )
    parser.add_argument(
        "--trg",
        metavar="TRG",
        type=str,
        required=True,
        help="Target language",
    )
    parser.add_argument(
        "--dataset",
        metavar="DATASET",
        type=str,
        required=True,
        help="Full dataset identifier. For example, sacrebleu_aug-upper-strict_wmt19 ",
    )
    parser.add_argument(
        "--output_prefix",
        metavar="OUTPUT_PREFIX",
        type=str,
        required=True,
        help="Write output dataset to a path with this prefix",
    )

    args = parser.parse_args()
    logger.info("Starting dataset import and augmentation.")
    run_import(args.type, args.dataset, args.output_prefix, args.src, args.trg)
    logger.info("Finished dataset import and augmentation.")


# Run the importer only when this file is executed as a script, not when it is imported
if __name__ == "__main__":
    main()
