# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import subprocess
from logging import Logger
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

from . import UserError
from .errors import Errors
from .filesystem import get_filesystem


# Module-level logger, named after this module; used for all progress and
# warning output emitted by the code below.
LOG: Logger = logging.getLogger(__name__)


class Configuration:
    """A Pyre configuration file loaded into memory for inspection and editing.

    The known fields of the file are exposed as attributes; unknown keys are
    kept untouched in `original_contents`. Call `write()` to persist changes
    made through the attributes or the helper methods.
    """

    def __init__(
        self, path: Path, json_contents: Optional[Dict[str, Any]] = None
    ) -> None:
        """Build a configuration for the file at `path`.

        Args:
            path: Location of the configuration file.
            json_contents: Already-parsed contents of the file. When omitted,
                `path` is opened and parsed as JSON.
        """
        if json_contents is None:
            with open(path, "r") as configuration_file:
                json_contents = json.load(configuration_file)
        self._path: Path = path
        # A configuration is "local" purely based on its file name.
        self.is_local: bool = path.name == ".pyre_configuration.local"
        self.root: str = str(path.parent)
        self.original_contents: Dict[str, Any] = json_contents

        # Configuration fields
        self.strict: Optional[bool] = json_contents.get("strict")
        self.targets: Optional[List[str]] = json_contents.get("targets")
        self.source_directories: Optional[List[str]] = json_contents.get(
            "source_directories"
        )
        self.version: Optional[str] = json_contents.get("version")
        self.pysa_version: Optional[str] = json_contents.get("pysa_version")
        self.use_buck_builder: Optional[bool] = json_contents.get("use_buck_builder")
        self.use_buck_source_database: Optional[bool] = json_contents.get(
            "use_buck_source_database"
        )
        self.use_command_v2: Optional[bool] = json_contents.get("use_command_v2")
        self.ignore_all_errors: Optional[List[str]] = json_contents.get(
            "ignore_all_errors"
        )

    def get_contents(self) -> Dict[str, Any]:
        """Return the configuration as a dictionary reflecting current fields.

        Assumption: The field names in this class match the key names in
        the configuration.

        The returned dictionary is `original_contents` itself, updated in
        place: fields that are set overwrite their key, fields that are `None`
        have their key removed, and keys this class does not know about are
        left as they were.
        """
        contents: Dict[str, Any] = self.original_contents
        for key in (
            "targets",
            "source_directories",
            "version",
            "pysa_version",
            "strict",
            "use_buck_builder",
            "use_buck_source_database",
            "use_command_v2",
            "ignore_all_errors",
        ):
            attribute = getattr(self, key)
            if attribute is not None:
                contents[key] = attribute
            else:
                contents.pop(key, None)
        return contents

    @staticmethod
    def find_parent_file(
        filename: str, directory: Optional[Path] = None
    ) -> Optional[Path]:
        """Search `directory` and each of its ancestors for a file `filename`.

        Args:
            filename: Name of the file to look for.
            directory: Where to start searching; defaults to the current
                working directory.

        Returns:
            The path of the closest match, or `None` when neither the starting
            directory nor any ancestor contains such a file.
        """
        start = directory or Path.cwd()
        # `Path.parents` yields every ancestor, nearest first, ending at the
        # filesystem root (or `.` for a relative path).
        for candidate_directory in (start, *start.parents):
            candidate = candidate_directory / filename
            if candidate.is_file():
                return candidate
        return None

    @staticmethod
    def find_project_configuration(directory: Optional[Path] = None) -> Path:
        """Return the closest `.pyre_configuration` at or above `directory`.

        Raises:
            UserError: If no such file exists in `directory` or its ancestors.
        """
        path = Configuration.find_parent_file(".pyre_configuration", directory)
        if path is None:
            raise UserError("No root with a `.pyre_configuration` found.")
        return path

    @staticmethod
    def find_local_configuration(directory: Optional[Path] = None) -> Optional[Path]:
        """Return the closest `.pyre_configuration.local` at or above
        `directory`, or `None` if there is none."""
        return Configuration.find_parent_file(".pyre_configuration.local", directory)

    @staticmethod
    def gather_local_configuration_paths(directory: str) -> Sequence[Path]:
        """Return the paths the filesystem layer lists under `directory` for
        the local configuration file pattern."""
        return [
            Path(path)
            for path in get_filesystem().list(
                directory, patterns=[r"**\.pyre_configuration.local"]
            )
        ]

    @staticmethod
    def gather_local_configurations() -> List["Configuration"]:
        """Load every local configuration found under the current directory.

        Files whose contents are not valid JSON are reported and skipped.
        """
        LOG.info("Finding configurations...")
        configuration_paths = Configuration.gather_local_configuration_paths(".")
        if not configuration_paths:
            LOG.info("No projects with local configurations found.")
            return []
        configurations = []
        for configuration_path in configuration_paths:
            with open(configuration_path) as configuration_file:
                try:
                    configuration = Configuration(
                        configuration_path, json.load(configuration_file)
                    )
                    configurations.append(configuration)
                except json.decoder.JSONDecodeError:
                    LOG.error(
                        "Configuration at `%s` is invalid, skipping.",
                        configuration_path,
                    )
        LOG.info(
            "Found %d local configuration%s.",
            len(configurations),
            "s" if len(configurations) != 1 else "",
        )
        return configurations

    def get_path(self) -> Path:
        """Return the path of the configuration file."""
        return self._path

    def get_directory(self) -> Path:
        """Return the directory containing the configuration file."""
        return self._path.parent

    def write(self) -> None:
        """Overwrite the configuration file with the current contents.

        The file is written as JSON with sorted keys, two-space indentation
        and a trailing newline.
        """
        with open(self._path, "w") as configuration_file:
            json.dump(self.get_contents(), configuration_file, sort_keys=True, indent=2)
            configuration_file.write("\n")

    def remove_version(self) -> None:
        """Clear the `version` field, logging when there is none to clear."""
        if not self.version:
            LOG.info("Version not found in configuration.")
            return
        self.version = None

    def set_version(self, version: str) -> None:
        """Set the `version` field."""
        self.version = version

    def set_pysa_version(self, pysa_version: str) -> None:
        """Set the `pysa_version` field."""
        self.pysa_version = pysa_version

    def enable_source_database_buck_builder(self) -> None:
        """Turn on both `use_buck_builder` and `use_buck_source_database`."""
        self.use_buck_builder = True
        self.use_buck_source_database = True

    def enable_new_server(self) -> None:
        """Turn on `use_command_v2`."""
        self.use_command_v2 = True

    def add_strict(self) -> None:
        """Set `strict` to true, logging when it already is."""
        if self.strict:
            LOG.info("Configuration is already strict.")
            return
        self.strict = True

    def add_targets(self, targets: List[str]) -> None:
        """Append `targets` to the configured targets.

        The caller's list is never stored directly, so later additions cannot
        modify it.
        """
        existing_targets = self.targets
        if existing_targets:
            existing_targets.extend(targets)
        else:
            # Copy: keeping a reference to the argument would let a later
            # `add_targets` call extend the caller's list in place.
            self.targets = list(targets)

    def deduplicate_targets(self) -> None:
        """Drop targets that are repeated or already covered by another target.

        Targets ending in `/...` or `:` are expanded with `buck query`; such a
        target is kept only if its expansion contributes at least one target
        not already covered. Targets whose query fails are kept. The surviving
        targets keep the relative order of their first occurrence.
        """
        targets = self.targets
        if not targets:
            return

        # Position of the first occurrence of each target, used to restore the
        # original ordering once deduplication is done.
        first_index: Dict[str, int] = {}
        for index, target in enumerate(targets):
            first_index.setdefault(target, index)

        glob_targets = {target for target in targets if target.endswith("/...")}
        non_glob_targets = {
            target for target in targets if not target.endswith("/...")
        }
        # Recursive globs are processed first so that the targets they expand
        # to are already known when the remaining targets are examined.
        all_targets = sorted(glob_targets) + sorted(non_glob_targets)

        deduplicated_targets = []
        expanded_targets = set()
        for target in all_targets:
            if target.endswith(("/...", ":")):
                try:
                    expanded = (
                        subprocess.check_output(["buck", "query", target])
                        .decode()
                        .strip()
                        .split("\n")
                    )
                except subprocess.CalledProcessError as error:
                    LOG.warning("Failed to query target: %s\n%s", target, str(error))
                    deduplicated_targets.append(target)
                else:
                    if not all(name in expanded_targets for name in expanded):
                        expanded_targets.update(expanded)
                        deduplicated_targets.append(target)
            elif target not in expanded_targets:
                expanded_targets.add(target)
                deduplicated_targets.append(target)
        deduplicated_targets.sort(key=first_index.__getitem__)
        self.targets = deduplicated_targets

    def run_pyre(
        self,
        arguments: List[str],
        description: str,
        should_clean: bool,
        command_input: Optional[str],
        stderr_flag: "subprocess._FILE" = subprocess.PIPE,
    ) -> Optional["subprocess.CompletedProcess[str]"]:
        """Run `pyre` with `arguments`, optionally after a `buck clean`.

        Args:
            arguments: Command-line arguments passed to `pyre`.
            description: Message logged right before `pyre` is started.
            should_clean: Whether to run `buck clean` first.
            command_input: Text fed to the standard input of `pyre`, if any.
            stderr_flag: Where the standard error of `pyre` is sent.

        Returns:
            The completed `pyre` process with its standard output captured as
            text, or `None` when `buck clean` timed out.
        """
        if should_clean:
            try:
                # If building targets, run clean or space may run out on device!
                LOG.info("Running `buck clean`...")
                # NOTE: `subprocess.call` reports failure through its return
                # code, which is discarded here, so a failing `buck clean`
                # does not stop the run.
                subprocess.call(["buck", "clean"], timeout=200)
            except subprocess.TimeoutExpired:
                LOG.warning("Buck timed out. Try running `buck kill` before retrying.")
                return None
            except subprocess.CalledProcessError as error:
                LOG.warning("Error calling `buck clean`: %s", str(error))
                return None
        try:
            LOG.info("%s", description)
            # NOTE: without `check=True` a non-zero exit status of `pyre` does
            # not raise; callers receive the completed process either way.
            return subprocess.run(
                ["pyre", *arguments],
                stdout=subprocess.PIPE,
                stderr=stderr_flag,
                text=True,
                input=command_input,
            )
        except subprocess.CalledProcessError as error:
            LOG.warning("Error calling pyre: %s", str(error))
            return None

    def get_errors(
        self,
        only_fix_error_code: Optional[int] = None,
        should_clean: bool = True,
        command_input: Optional[str] = None,
        strict: bool = False,
    ) -> "Errors":
        """Run `pyre check` for this configuration and parse its JSON output.

        Args:
            only_fix_error_code: Forwarded to `Errors.from_json` together with
                the output of `pyre`.
            should_clean: Whether to run `buck clean` first; only honored when
                the configuration has `targets`.
            command_input: Text fed to the standard input of `pyre`, if any.
            strict: Whether to pass `--strict` to `pyre`.

        Returns:
            The parsed errors, or `Errors.empty()` when `pyre` did not run or
            produced no standard output.

        Raises:
            UserError: Propagated from `Errors.from_json` after the raw output
                of `pyre` has been logged.
        """
        local_root_arguments = (
            ["--local-configuration", self.root] if self.is_local else []
        )
        strict_arguments = ["--strict"] if strict else []
        arguments = [*strict_arguments, *local_root_arguments, "--output=json", "check"]
        pyre_output = self.run_pyre(
            arguments=arguments,
            description=f"Checking `{self.root}`...",
            should_clean=self.targets is not None and should_clean,
            command_input=command_input,
        )
        if not pyre_output:
            return Errors.empty()

        stdout = pyre_output.stdout
        if stdout is None:
            return Errors.empty()

        stdout = stdout.strip()

        try:
            errors = Errors.from_json(stdout, only_fix_error_code)
        except UserError:
            LOG.info("Error when parsing Pyre error output.")
            LOG.info("Pyre stdout: %s\nPyre stderr: %s", stdout, pyre_output.stderr)
            # Bare `raise` re-raises the active exception unchanged.
            raise

        LOG.info("Found %d error%s.", len(errors), "s" if len(errors) != 1 else "")
        return errors
