# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools that downloads 🤗 Transformers training script examples and prepares them for AWS Trainium instances."""

import re
import shutil
import subprocess
from argparse import ArgumentParser
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Callable, List, Optional, Tuple, Union

from git import Repo


REPO_URL = "https://github.com/huggingface/transformers.git"

SUPPORTED_EXAMPLES = [
    "text-classification",
    "token-classification",
    "question-answering",
    "multiple-choice",
    "image-classification",
    "language-modeling",
    "translation",
    "summarization",
]

UNSUPPORTED_SCRIPTS_FOR_NOW = [
    "run_plm.py",
    "run_qa_beam_search.py",
]


IMPORT_PATTERN = r"from transformers import \(?[\w\s,_]*?([\t ]*({class_pattern}),?\n?)((?!from)[\w\s,_])*\)?"

TRAINER_IMPORT_PATTERN = re.compile(IMPORT_PATTERN.format(class_pattern="(Seq2Seq)?Trainer"))
HF_ARGUMENT_PARSER_IMPORT_PATTERN = re.compile(IMPORT_PATTERN.format(class_pattern="HfArgumentParser"))
TRAINING_ARGUMENTS_IMPORT_PATTERN = re.compile(IMPORT_PATTERN.format(class_pattern="(Seq2Seq)?TrainingArguments"))


# Matches a torch requirement line (e.g. "torch >= 1.3") so it can be stripped from requirements.txt:
# Trainium instances rely on the torch build shipped with the Neuron SDK, which a pinned torch
# requirement could override.
TORCH_REQUIREMENT_PATTERN = re.compile(r"torch[\w\s]*([<>=!]=?\s*[\d\.]+)?\n")


# Maps each transformers class to the equivalent optimum.neuron import, aliased so that the rest of the
# training script can keep using the original class name.
AWS_CODE = {
    "Trainer": "NeuronTrainer as Trainer",
    "Seq2SeqTrainer": "Seq2SeqNeuronTrainer as Seq2SeqTrainer",
    "HfArgumentParser": "NeuronHfArgumentParser as HfArgumentParser",
    "TrainingArguments": "NeuronTrainingArguments as TrainingArguments",
    "Seq2SeqTrainingArguments": "Seq2SeqNeuronTrainingArguments as Seq2SeqTrainingArguments",
}
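
# For instance (illustrative), a script containing:
#
#   from transformers import AutoModel, Trainer
#
# is roughly rewritten into:
#
#   from transformers import AutoModel
#   from optimum.neuron import NeuronTrainer as Trainer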


def download_examples_from_transformers(
    example_names: List[str],
    dest_dir: Union[str, Path],
    predicate: Optional[Callable[[Path], bool]] = None,
    version: Optional[str] = None,
):
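    """Clones the 🤗 Transformers repo and copies the requested PyTorch example directories into `dest_dir`.

    Only the files for which `predicate` returns `True` are copied. When `version` is provided, the
    matching `v{version}-(release|patch)` branch is checked out first; otherwise the main branch is used.
    """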
    if isinstance(dest_dir, str):
        dest_dir = Path(dest_dir)

    if predicate is None:

        def predicate(_):
            return True

    with TemporaryDirectory() as tmpdirname:
        repo = Repo.clone_from(REPO_URL, tmpdirname)
        if version is not None:
            pattern = rf"v{version}-(release|patch)"
            match_ = re.search(pattern, repo.git.branch("--all"))
            if match_ is None:
                raise ValueError(f"Could not find the {version} version in the Transformers repo.")
            repo.git.checkout(match_.group(0))

        path_prefix = Path(tmpdirname) / "examples" / "pytorch"
        dest_dir.mkdir(parents=True, exist_ok=True)

        for example in example_names:
            example_dir = path_prefix / example
            for file_path in example_dir.iterdir():
                if predicate(file_path):
                    dest_example_dir = dest_dir / example
                    dest_example_dir.mkdir(parents=True, exist_ok=True)
                    shutil.copy(file_path, dest_example_dir / file_path.name)


def keep_only_examples_with_trainer_and_requirements_predicate(file_path: Path) -> bool:
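    """Keeps the `.py` training scripts that rely on the Trainer API, along with their requirements.txt,
    filtering out the `no_trainer` variants and the scripts listed in `UNSUPPORTED_SCRIPTS_FOR_NOW`."""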
    is_python_or_text = file_path.suffix in [".py", ".txt"]
    is_supported = file_path.name not in UNSUPPORTED_SCRIPTS_FOR_NOW
    not_a_no_trainer_script = "no_trainer" not in file_path.name
    is_requirements = file_path.name == "requirements.txt"
    return is_python_or_text and is_supported and (not_a_no_trainer_script or is_requirements)


def remove_import(pattern: re.Pattern, file_content: str) -> Tuple[str, str, int]:
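    """Removes the class captured by `pattern` from its `from transformers import ...` statement.

    Returns the removed class name, the new file content, and the index right after the shortened import
    statement, where a replacement import can be inserted.
    """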
    match_ = re.search(pattern, file_content)
    if match_ is None:
        raise ValueError(f"Could not find a match for pattern {pattern}.")
    cls_ = match_.group(2)
    new_content = file_content[: match_.start(1)] + file_content[match_.end(1) :]
    return cls_, new_content, match_.end(0) - (match_.end(1) - match_.start(1))


def remove_trainer_import(file_content: str) -> Tuple[str, str, int]:
    """Same as `remove_import`, specialized for the Trainer / Seq2SeqTrainer import."""
    match_ = re.search(TRAINER_IMPORT_PATTERN, file_content)
    if match_ is None:
        raise ValueError("Could not find the import of the Trainer class from transformers.")
    trainer_cls = match_.group(2)
    new_content = file_content[: match_.start(1)] + file_content[match_.end(1) :]
    return trainer_cls, new_content, match_.end(0) - (match_.end(1) - match_.start(1))


def insert_code_at_position(code: str, file_content: str, position: int) -> str:
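    """Inserts `code` in `file_content` at the given character `position`."""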
    return file_content[:position] + code + file_content[position:]


def generate_new_import_code(*optimum_neuron_imports: str) -> str:
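    """Builds a single `from optimum.neuron import ...` statement from the given import specifications.

    For example:
        generate_new_import_code("NeuronTrainer as Trainer", "NeuronHfArgumentParser as HfArgumentParser")
        # -> "from optimum.neuron import NeuronTrainer as Trainer, NeuronHfArgumentParser as HfArgumentParser"
    """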
    if not optimum_neuron_imports:
        raise ValueError("At least one import is expected to generate new import code.")
    import_line = ["from optimum.neuron import"]
    import_line += [f"{import_}," for import_ in optimum_neuron_imports[:-1]]
    import_line.append(optimum_neuron_imports[-1])
    return " ".join(import_line)


def parse_args():
    parser = ArgumentParser(
        description="Tool to download and prepare 🤗 Transformers example training scripts for AWS Trainium instances."
    )
    parser.add_argument(
        "--version",
        default=None,
        type=str,
        help="The version of Transformers from which the examples will be downloaded. By default the main branch is used.",
    )
    parser.add_argument(
        "--examples",
        default="all",
        action="store",
        type=str,
        nargs="+",
        help="The names of the examples to download. By default all the supported examples will be downloaded.",
    )
    parser.add_argument("dest", type=Path, help="The directory in which the examples will be saved.")
    return parser.parse_args()


def main():
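    """Downloads the requested examples and patches them to use the optimum.neuron classes."""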
    args = parse_args()
    examples = args.examples
    if examples == "all":
        examples = SUPPORTED_EXAMPLES
    download_examples_from_transformers(
        examples, args.dest, predicate=keep_only_examples_with_trainer_and_requirements_predicate, version=args.version
    )

    for example_dir in args.dest.iterdir():
        if example_dir.is_file():
            continue
        for file_path in example_dir.iterdir():
            if file_path.name == "run_generation.py":
                continue
            if "run" in file_path.name and file_path.suffix == ".py":
                if file_path.name == "run_qa.py":
                    trainer_file_path = file_path.parent / "trainer_qa.py"
                elif file_path.name == "run_seq2seq_qa.py":
                    trainer_file_path = file_path.parent / "trainer_seq2seq_qa.py"
                else:
                    trainer_file_path = file_path
                hf_argument_file_path = file_path
                training_argument_file_path = file_path

                print(f"Processing {file_path}")
                with open(trainer_file_path, "r") as fp:
                    file_content = fp.read()
                trainer_cls, processed_content, import_end_index = remove_import(TRAINER_IMPORT_PATTERN, file_content)
                code = generate_new_import_code(AWS_CODE[trainer_cls])
                code = f"\n{code}\n"
                processed_content = insert_code_at_position(code, processed_content, import_end_index)
                with open(trainer_file_path, "w") as fp:
                    fp.write(processed_content)

                with open(hf_argument_file_path, "r") as fp:
                    file_content = fp.read()
                _, processed_content, import_end_index = remove_import(HF_ARGUMENT_PARSER_IMPORT_PATTERN, file_content)
                code = generate_new_import_code(AWS_CODE["HfArgumentParser"])
                code = f"\n{code}\n"
                processed_content = insert_code_at_position(code, processed_content, import_end_index)
                with open(hf_argument_file_path, "w") as fp:
                    fp.write(processed_content)

                with open(training_argument_file_path, "r") as fp:
                    file_content = fp.read()
                training_args_cls, processed_content, import_end_index = remove_import(
                    TRAINING_ARGUMENTS_IMPORT_PATTERN, file_content
                )
                code = generate_new_import_code(AWS_CODE[training_args_cls])
                code = f"\n{code}\n"
                processed_content = insert_code_at_position(code, processed_content, import_end_index)
                with open(training_argument_file_path, "w") as fp:
                    fp.write(processed_content)

            elif file_path.name == "requirements.txt":
                with open(file_path, "r") as fp:
                    file_content = fp.read()
                processed_content = re.sub(TORCH_REQUIREMENT_PATTERN, "", file_content)
                if file_path.parent.name == "image-classification":
                    processed_content += "\nscikit-learn"
                with open(file_path, "w") as fp:
                    fp.write(processed_content)

    # Linting and styling.
    subprocess.run(["ruff", f"{args.dest}", "--fix"])


if __name__ == "__main__":
    main()
