# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
from __future__ import annotations

import ast
import builtins
from contextlib import nullcontext
from symtable import SymbolTable as PythonSymbolTable, SymbolTableFactory
from types import CodeType
from typing import (
    Callable,
    ContextManager,
    Iterable,
    List,
    Optional,
    Tuple,
    final,
    Dict,
    Set,
    TYPE_CHECKING,
)

from _static import __build_cinder_class__
from _strictmodule import (
    StrictAnalysisResult,
    StrictModuleLoader,
    NONSTRICT_MODULE_KIND,
    STATIC_MODULE_KIND,
    STUB_KIND_MASK_TYPING,
)

from ..errors import TypedSyntaxError
from ..pycodegen import compile as python_compile
from ..readonly import readonly_compile
from ..static import Compiler as StaticCompiler, ModuleTable, StaticCodeGenerator
from . import strict_compile
from .class_conflict_checker import check_class_conflict
from .common import StrictModuleError
from .rewriter import rewrite, remove_annotations

if TYPE_CHECKING:
    from _strictmodule import IStrictModuleLoader, StrictModuleLoaderFactory


def getSymbolTable(mod: StrictAnalysisResult) -> PythonSymbolTable:
    """
    Construct a symtable object from analysis result
    """
    return SymbolTableFactory()(mod.symtable, mod.file_name)


# Type of a timing logger: called with (module name, file name, phase label),
# e.g. "declaration_visit" or "compile", and used as a context manager around
# the timed work. ``Compiler``'s ``log_time_func`` is a zero-argument callable
# that returns one of these.
TIMING_LOGGER_TYPE = Callable[[str, str, str], ContextManager[None]]


@final
class Compiler(StaticCompiler):
    """Compile modules according to their strict-module analysis.

    Every module is first analysed by a strict-module loader. Modules that are
    invalid, report analysis errors, or are non-strict (unless strictness is
    forced) are compiled as ordinary Python. Strict modules are rewritten and
    compiled with ``strict_compile``. Static modules are compiled through the
    inherited Static Python ``compile`` machinery.
    """

    def __init__(
        self,
        import_path: Iterable[str],
        stub_root: str,
        allow_list_prefix: Iterable[str],
        allow_list_exact: Iterable[str],
        log_time_func: Optional[Callable[[], TIMING_LOGGER_TYPE]] = None,
        raise_on_error: bool = False,
        enable_patching: bool = False,
        loader_factory: StrictModuleLoaderFactory = StrictModuleLoader,
        use_py_compiler: bool = False,
    ) -> None:
        """Create the compiler and its strict-module loader.

        Args:
            import_path: search path handed to the loader.
            stub_root: stub directory handed to the loader (as ``str``).
            allow_list_prefix: module-name prefixes handed to the loader.
            allow_list_exact: exact module names handed to the loader.
            log_time_func: optional zero-argument callable returning a timing
                logger (see ``TIMING_LOGGER_TYPE``); ``None`` disables timing.
            raise_on_error: raise ``StrictModuleError`` on the first error
                instead of recording it and continuing.
            enable_patching: forwarded to the static ``compile`` call.
            loader_factory: constructor used to build ``self.loader``.
            use_py_compiler: compile non-strict code with this package's
                ``pycodegen.compile`` instead of the builtin ``compile``.
        """
        super().__init__(StaticCodeGenerator)
        self.import_path: List[str] = list(import_path)
        self.stub_root = stub_root
        # Materialize the allow lists once so the stored attributes stay usable
        # even when one-shot iterators (e.g. generators) are passed in.
        self.allow_list_prefix: List[str] = list(allow_list_prefix)
        self.allow_list_exact: List[str] = list(allow_list_exact)
        self.loader: IStrictModuleLoader = loader_factory(
            self.import_path,
            str(stub_root),
            # Hand the loader its own copies of the allow lists.
            list(self.allow_list_prefix),
            list(self.allow_list_exact),
            # NOTE(review): the meaning of this positional flag is defined by
            # the loader factory's signature; confirm there.
            True,
        )
        self.raise_on_error = raise_on_error
        self.log_time_func = log_time_func
        self.enable_patching = enable_patching
        # Read by _get_rewritten_ast(); updated by each _compile_static() call.
        self.track_import_call: bool = False
        # Names already found to be valid but not static; import_module()
        # returns None for these without re-running analysis.
        self.not_static: Set[str] = set()
        self.use_py_compiler = use_py_compiler
        # Snapshot of builtins taken at construction time, i.e. before this
        # instance patches builtins.__build_class__ in
        # _ensure_build_class_patched(). Read through the builtins module: the
        # module-level ``__builtins__`` name is a CPython implementation detail
        # and is a module object (not a dict) in ``__main__``.
        self.original_builtins: Dict[str, object] = dict(vars(builtins))

    def _timing_context(
        self, name: str, filename: str, phase: str
    ) -> ContextManager[None]:
        """Return a context manager that times ``phase`` for module ``name``.

        Falls back to ``nullcontext()`` when no ``log_time_func`` is configured.
        """
        log = self.log_time_func
        if not log:
            return nullcontext()
        return log()(name, filename, phase)

    def import_module(self, name: str, optimize: int) -> Optional[ModuleTable]:
        """Return the static ``ModuleTable`` for ``name``, declaring it if needed.

        Returns the cached entry from ``self.modules`` when present, and
        ``None`` when ``name`` was previously recorded in ``self.not_static``.
        Otherwise the loader analyses the module. A valid, error-free static
        module has its preprocessed AST rewritten and declared via
        ``add_module``. A valid, error-free module of any other kind is
        recorded in ``self.not_static``.
        """
        res = self.modules.get(name)
        if res is not None:
            return res

        if name in self.not_static:
            return None

        mod = self.loader.check(name)
        if mod.is_valid and name not in self.modules and not mod.errors:
            if mod.module_kind == STATIC_MODULE_KIND:
                root = mod.ast_preprocessed
                if STUB_KIND_MASK_TYPING & mod.stub_kind:
                    # Stub kinds carrying the typing mask bit have their
                    # annotations removed before rewriting.
                    root = remove_annotations(root)
                root = self._get_rewritten_ast(name, mod, root, optimize)
                timer = self._timing_context(
                    name, mod.file_name, "declaration_visit"
                )
                with timer:
                    self.add_module(name, mod.file_name, root, optimize)
            else:
                self.not_static.add(name)
            # NOTE(review): invalid modules, or modules with analysis errors,
            # are not recorded anywhere and are re-checked on the next call.

        return self.modules.get(name)

    def _get_rewritten_ast(
        self, name: str, mod: StrictAnalysisResult, root: ast.Module, optimize: int
    ) -> ast.Module:
        """Rewrite ``root`` for static compilation of module ``name``.

        Uses the symbol table from ``mod``, the current
        ``self.track_import_call`` flag and the builtins snapshot.
        """
        symbols = getSymbolTable(mod)
        return rewrite(
            root,
            symbols,
            mod.file_name,
            name,
            optimize=optimize,
            is_static=True,
            track_import_call=self.track_import_call,
            builtins=self.original_builtins,
        )

    def load_compiled_module_from_source(
        self,
        source: str | bytes,
        filename: str,
        name: str,
        optimize: int,
        submodule_search_locations: Optional[List[str]] = None,
        track_import_call: bool = False,
        force_strict: bool = False,
    ) -> Tuple[CodeType | None, bool]:
        """Analyse ``source`` and compile it in the appropriate mode.

        Args:
            source: module source.
            filename: file name used for analysis and compilation.
            name: module name.
            optimize: optimization level forwarded to the compilers.
            submodule_search_locations: forwarded to the loader; ``None`` is
                passed as an empty list.
            track_import_call: forwarded to the strict/static rewriters.
            force_strict: calls ``loader.set_force_strict_by_name(name)``. It
                also makes a valid, error-free module count as strict even
                when its kind is non-strict.

        Returns:
            ``(code, is_valid_strict)``. ``code`` may be ``None`` when static
            compilation fails and ``raise_on_error`` is false.

        Raises:
            StrictModuleError: when ``raise_on_error`` is set and analysis,
                the class-conflict check, or static compilation reports an
                error.
        """
        if force_strict:
            self.loader.set_force_strict_by_name(name)
        mod = self.loader.check_source(
            source, filename, name, submodule_search_locations or []
        )
        errors = mod.errors
        is_valid_strict = (
            mod.is_valid
            and not errors
            and (force_strict or mod.module_kind != NONSTRICT_MODULE_KIND)
        )
        if errors and self.raise_on_error:
            # Only the first analysis error is reported.
            error = errors[0]
            raise StrictModuleError(error[0], error[1], error[2], error[3])
        elif is_valid_strict:
            symbols = getSymbolTable(mod)
            try:
                check_class_conflict(mod.ast, filename, symbols)
            except StrictModuleError as e:
                if self.raise_on_error:
                    raise
                # NOTE(review): the conflict is recorded, but is_valid_strict
                # stays True, so the module is still compiled as strict/static.
                # Confirm this is intended.
                mod.errors.append((e.msg, e.filename, e.lineno, e.col))

        if not is_valid_strict:
            code = self._compile_basic(mod.ast, filename, optimize)
        elif mod.module_kind == STATIC_MODULE_KIND:
            code = self._compile_static(
                mod, filename, name, optimize, track_import_call
            )
        else:
            code = self._compile_strict(
                mod, filename, name, optimize, track_import_call
            )

        return code, is_valid_strict

    def _compile_readonly(
        self, name: str, root: ast.Module, filename: str, optimize: int
    ) -> CodeType:
        """Compile ``root`` with ``readonly_compile`` (flags=0).

        Not referenced elsewhere in this class.

        TODO: this should eventually replace compile_basic once all python sources
        are compiled through this compiler
        """
        return readonly_compile(name, filename, root, flags=0, optimize=optimize)

    def _compile_basic(
        self, root: ast.Module, filename: str, optimize: int
    ) -> CodeType:
        """Compile ``root`` as ordinary Python in ``"exec"`` mode.

        Uses this package's ``pycodegen.compile`` when ``use_py_compiler`` is
        set, otherwise the builtin ``compile`` (not ``self.compile``).
        """
        compile_method = python_compile if self.use_py_compiler else compile
        return compile_method(
            root,
            filename,
            "exec",
            optimize=optimize,
        )

    def _compile_strict(
        self,
        mod: StrictAnalysisResult,
        filename: str,
        name: str,
        optimize: int,
        track_import_call: bool,
    ) -> CodeType:
        """Rewrite and compile a strict (non-static) module.

        Unlike ``_compile_static``, this passes the ``track_import_call``
        argument straight to ``rewrite`` and does not set ``is_static``.
        """
        symbols = getSymbolTable(mod)
        tree = rewrite(
            mod.ast_preprocessed,
            symbols,
            filename,
            name,
            optimize=optimize,
            track_import_call=track_import_call,
            builtins=self.original_builtins,
        )
        return strict_compile(name, filename, tree, optimize, self.original_builtins)

    def _compile_static(
        self,
        mod: StrictAnalysisResult,
        filename: str,
        name: str,
        optimize: int,
        track_import_call: bool,
    ) -> CodeType | None:
        """Compile a static module with the inherited Static Python compiler.

        Reuses the AST from ``self.ast_cache`` when present, otherwise rewrites
        ``mod.ast_preprocessed``. On ``TypedSyntaxError`` the error is recorded
        in ``mod.errors`` as a ``(msg, filename, lineno, col)`` tuple. It is
        re-raised as ``StrictModuleError`` when ``raise_on_error`` is set;
        otherwise ``None`` is returned.
        """
        self._ensure_build_class_patched()
        # Stored on the instance because _get_rewritten_ast() reads it, both
        # here and for modules later declared through import_module().
        self.track_import_call = track_import_call
        root = self.ast_cache.get(name)
        if root is None:
            root = self._get_rewritten_ast(name, mod, mod.ast_preprocessed, optimize)
        code = None

        try:
            with self._timing_context(name, filename, "compile"):
                code = self.compile(
                    name,
                    filename,
                    root,
                    optimize,
                    enable_patching=self.enable_patching,
                    builtins=self.original_builtins,
                )
        except TypedSyntaxError as e:
            msg = e.msg or "unknown error during static compilation"
            err_filename = e.filename or filename
            lineno = e.lineno or 1
            # Same tuple shape as the other entries of mod.errors (see the
            # class-conflict handling and the errors[0][...] indexing).
            mod.errors.append((msg, err_filename, lineno, 0))
            if self.raise_on_error:
                raise StrictModuleError(msg, err_filename, lineno, 0) from e

        return code

    def _ensure_build_class_patched(self) -> None:
        """Install ``__build_cinder_class__`` as ``builtins.__build_class__``.

        Does nothing if it is already installed.
        """
        # pyre-ignore[61]: __build_class__ isn't exposed in builtins.pyi.
        if builtins.__build_class__ is not __build_cinder_class__:
            builtins.__build_class__ = __build_cinder_class__
