# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
import ast
from ast import AST
from compiler import compile as compiler_compile
from compiler.optimizer import BIN_OPS, is_const, get_const_value
from compiler.py38.optimizer import AstOptimizer38
from compiler.pyassem import PyFlowGraph38
from compiler.pycodegen import Python38CodeGenerator
from compiler.symbols import SymbolVisitor
from compiler.visitor import ASTVisitor, walk

import _compiler_opcode as opcodepyro


def should_rewrite_printf(node):
    """Return True when ``node`` has the shape ``<str literal> % <expr>``.

    ``node`` is expected to be an ``ast.BinOp``; only its ``left`` and ``op``
    attributes are inspected.
    """
    lhs = node.left
    # Equivalent to ``isinstance(lhs, ast.Str)``: a Constant holding a str.
    is_str_literal = isinstance(lhs, ast.Constant) and isinstance(
        getattr(lhs, "value", None), str
    )
    if not is_str_literal:
        return False
    return isinstance(node.op, ast.Mod)


def create_conversion_call(name, value):
    """Build the expression ``''.<name>(value)``.

    The method is looked up on an empty-string literal so the call reaches a
    str method regardless of what names the surrounding code defines.

    Args:
        name: attribute (method) name to call on ``''``.
        value: AST expression passed as the single positional argument.

    Returns:
        An ``ast.Call`` node with no keyword arguments.
    """
    receiver = ast.Constant("")
    bound_method = ast.Attribute(value=receiver, attr=name, ctx=ast.Load())
    return ast.Call(func=bound_method, args=[value], keywords=[])


def try_constant_fold_mod(format_string, right):
    """Evaluate ``format_string % right`` now, where ``right`` is a constant node.

    Returns a string-constant AST node holding the formatted result. Any
    exception raised by the formatting (e.g. an argument-count mismatch)
    propagates to the caller.
    """
    operand = get_const_value(right)
    folded = format_string.__mod__(operand)
    return ast.Constant(folded)


class AstOptimizerPyro(AstOptimizer38):
    """``AstOptimizer38`` extended with a rewrite of ``"<literal>" % args``.

    When the left operand of ``%`` is a string literal, the expression is
    either folded into a constant string or rewritten into an equivalent
    f-string (``ast.JoinedStr``). Conversions that need runtime checks call
    helpers looked up on the empty string (``''._mod_check_single_arg``,
    ``''._mod_convert_number_int``, ``''._mod_convert_number_index``); those
    helpers are assumed to be provided by the Pyro runtime.
    """

    @staticmethod
    def _count_printf_specifiers(format_string):
        """Count the value-consuming ``%`` specifiers in ``format_string``.

        ``%%`` escapes are skipped. Returns ``None`` when the string ends in a
        lone ``%`` or uses a ``%(name)`` mapping key; neither is rewritten.
        """
        count = 0
        i = 0
        length = len(format_string)
        while i < length:
            i = format_string.find("%", i)
            if i == -1:
                break
            i += 1
            if i >= length:
                # Invalid format string ending in a single percent
                return None
            ch = format_string[i]
            i += 1
            if ch == "%":
                continue
            if ch == "(":
                # We don't support dict lookups and may get confused from
                # inner '%' chars
                return None
            count += 1
        return count

    def _try_fold_str_mod(self, format_string, right):
        """Fold ``format_string % right`` into a string constant if possible.

        Folding is attempted when ``right`` is a constant, or a tuple display
        whose elements are all constants. Returns the folded node or ``None``.
        Formatting errors are swallowed deliberately so they are not raised at
        compile time.
        """
        try:
            if is_const(right):
                return try_constant_fold_mod(format_string, right)
            # Only a tuple display supplies multiple arguments to `%`; a list
            # or set display is a single argument, so it must not be folded
            # as if it were a tuple.
            if isinstance(right, ast.Tuple):
                const_tuple = self.makeConstTuple(right.elts)
                if const_tuple:
                    return ast.Str(format_string.__mod__(const_tuple.value))
        except Exception:
            pass
        return None

    def rewrite_str_mod(self, left, right):  # noqa: C901
        """Try to replace ``left % right``, where ``left`` is a string literal.

        Supported specifiers are ``%%``, ``%s``/``%r``/``%a``,
        ``%d``/``%i``/``%u`` and ``%x``/``%X``/``%o``, optionally preceded by
        ``0`` flags and a decimal width.

        Returns:
            An ``ast.Str`` when the whole expression can be folded now, an
            ``ast.JoinedStr`` computing the same string otherwise, or ``None``
            when the expression should be left alone. That happens for an
            unsupported specifier or flag, a mapping key, a trailing lone
            ``%``, a starred RHS element, or a specifier/value count mismatch.
        """
        format_string = left.s
        folded = self._try_fold_str_mod(format_string, right)
        if folded is not None:
            return folded

        n_specifiers = self._count_printf_specifiers(format_string)
        if n_specifiers is None:
            return None

        if isinstance(right, ast.Tuple):
            rhs_values = right.elts
            if any(isinstance(elt, ast.Starred) for elt in rhs_values):
                # `(*xs,)` has a length only known at runtime, so its elements
                # cannot be matched to specifiers statically.
                return None
            num_values = len(rhs_values)
        else:
            # If RHS is not a tuple constructor, then we only support the
            # situation with a single format specifier in the string, by
            # normalizing `right` to a one-element tuple:
            # `''._mod_check_single_arg(right)[0]`
            if n_specifiers != 1:
                return None
            rhs_values = None
            num_values = 1

        length = len(format_string)
        i = 0
        value_idx = 0
        segment_begin = 0
        strings = []
        while i < length:
            i = format_string.find("%", i)
            if i == -1:
                break
            # Emit the literal text between the previous specifier and '%'.
            if i > segment_begin:
                strings.append(ast.Str(format_string[segment_begin:i]))
            i += 1
            if i >= length:
                return None
            ch = format_string[i]
            i += 1

            # Parse flags and width; `spec_begin` is the first char after '%'.
            spec_begin = i - 1
            # TODO(matthiasb): Support ' ', '+', '#', etc
            # They mostly have the same meaning. However they can
            # appear in any order here but must follow stricter
            # conventions in f-strings.
            while ch == "0":
                if i >= length:
                    return None
                ch = format_string[i]
                i += 1
            if "1" <= ch <= "9":
                if i >= length:
                    return None
                ch = format_string[i]
                i += 1
                while "0" <= ch <= "9":
                    if i >= length:
                        return None
                    ch = format_string[i]
                    i += 1
            # Zero flags followed by the width; excludes the conversion `ch`.
            spec_str = format_string[spec_begin : i - 1]

            if ch == "%":
                # Handle '%%': the second '%' starts the next literal segment.
                segment_begin = i - 1
                continue

            # Handle remaining supported cases that use a value from RHS
            if rhs_values is not None:
                if value_idx >= num_values:
                    return None
                value = rhs_values[value_idx]
            else:
                # We have a situation like `"%s" % x` without tuple on RHS.
                # Transform to: f"{''._mod_check_single_arg(x)[0]}"
                converted = create_conversion_call("_mod_check_single_arg", right)
                value = ast.Subscript(converted, ast.Index(ast.Num(0)), ast.Load())
            value_idx += 1

            if ch in "sra":
                # Rewrite "%s" % (x,) to f"{x!s}". The '0' flag has no effect
                # on %s/%r/%a (they always pad with spaces), whereas in a
                # format spec it would zero-fill, so keep only the width.
                # The explicit '>' is needed because `%5s` aligns right while
                # `f"{x:5}"` aligns left for strings.
                width = spec_str.lstrip("0")
                format_spec = ast.Str(">" + width) if width else None
                strings.append(ast.FormattedValue(value, ord(ch), format_spec))
            elif ch in "diu":
                # Rewrite "%d" % (x,) to f"{''._mod_convert_number_int(x)}".
                # Calling a method on the empty string is a hack to access a
                # well-known function regardless of the surrounding
                # environment.
                converted = create_conversion_call("_mod_convert_number_int", value)
                format_spec = ast.Str(spec_str) if spec_str else None
                strings.append(ast.FormattedValue(converted, -1, format_spec))
            elif ch in "xXo":
                # Rewrite "%x" % (v,) to f"{''._mod_convert_number_index(v):x}".
                # Calling a method on the empty string is a hack to access a
                # well-known function regardless of the surrounding
                # environment.
                converted = create_conversion_call("_mod_convert_number_index", value)
                format_spec = ast.Str(spec_str + ch)
                strings.append(ast.FormattedValue(converted, -1, format_spec))
            else:
                return None
            # Begin next segment after specifier
            segment_begin = i

        if value_idx != num_values:
            return None

        if segment_begin < length:
            strings.append(ast.Str(format_string[segment_begin:]))

        return ast.JoinedStr(strings)

    def visitBinOp(self, node: ast.BinOp) -> ast.expr:
        """Fold constant binary operations and rewrite ``"..." % args``.

        Operands are optimized first, and the printf rewrite checks the
        optimized left operand. It therefore also applies when the format
        string only became a literal through folding, e.g. ``("%s" + "!") % x``.
        """
        left = self.visit(node.left)
        right = self.visit(node.right)

        if is_const(left) and is_const(right):
            handler = BIN_OPS.get(type(node.op))
            if handler is not None:
                lval = get_const_value(left)
                rval = get_const_value(right)
                try:
                    return ast.copy_location(ast.Constant(handler(lval, rval)), node)
                except Exception:
                    # Do not raise at compile time; the operation is kept
                    # as-is below.
                    pass

        if isinstance(node.op, ast.Mod) and isinstance(left, ast.Str):
            result = self.rewrite_str_mod(left, right)
            if result is not None:
                return self.visit(ast.copy_location(result, node))

        return self.update_node(node, left=left, right=right)


class PyroFlowGraph(PyFlowGraph38):
    """``PyFlowGraph38`` whose ``opcode`` table comes from ``_compiler_opcode``."""

    opcode = opcodepyro.opcode


class ComprehensionRenamer(ASTVisitor):
    """Rewrite names bound by a comprehension to prefixed, per-round names.

    While ``is_target`` is set, every ``Name`` in Store/Del context is renamed
    to ``prefix + name`` and the mapping is recorded in ``new_names``. At other
    times, ``Name`` nodes and function/lambda ``arg`` nodes whose name was
    recorded are rewritten to the recorded new name. Nodes are mutated in
    place.
    """

    def __init__(self, scope):
        super().__init__()
        # The round counter is stored on `scope` so each renaming round in
        # the same scope gets a distinct prefix: "_gen$", "_gen1$", "_gen2$", ...
        round_no = getattr(scope, "last_comprehension_rename_index", -1) + 1
        scope.last_comprehension_rename_index = round_no
        suffix = str(round_no) if round_no > 0 else ""
        self.prefix = "_gen" + suffix + "$"
        # Original name -> prefixed name, filled while visiting targets.
        self.new_names = {}
        # Set by the caller while visiting a comprehension target.
        self.is_target = False

    def visitName(self, node):
        binds_name = self.is_target and isinstance(node.ctx, (ast.Store, ast.Del))
        if binds_name:
            renamed = self.prefix + node.id
            self.new_names[node.id] = renamed
        else:
            renamed = self.new_names.get(node.id)
            if renamed is None:
                return
        node.id = renamed

    def visitarg(self, node):
        renamed = self.new_names.get(node.arg)
        if renamed is None:
            return
        node.arg = renamed


class PyroSymbolVisitor(SymbolVisitor):
    """``SymbolVisitor`` that renames comprehension variables before binding.

    List, set and dict comprehensions are passed through a
    ``ComprehensionRenamer`` before their parts are visited for symbols, so
    the symbol table records the renamed (prefixed) names.
    """

    def visitDictCompListCompSetComp(self, node, scope):
        generators = node.generators

        # Check for unexpected assignments: the outermost iterable is visited
        # first with `comp_iter_expr` raised.
        scope.comp_iter_expr += 1
        self.visit(generators[0].iter, scope)
        scope.comp_iter_expr -= 1

        renamer = ComprehensionRenamer(scope)
        for position, gen in enumerate(generators):
            # Rename uses in the iterable before the target's own bindings
            # are recorded, then rename the target and its conditions.
            renamer.visit(gen.iter)
            renamer.is_target = True
            renamer.visit(gen.target)
            renamer.is_target = False
            for condition in gen.ifs:
                renamer.visit(condition)

            self.visitcomprehension(gen, scope, position == 0)

        if isinstance(node, ast.DictComp):
            result_exprs = (node.value, node.key)
        else:
            result_exprs = (node.elt,)
        for expr in result_exprs:
            renamer.visit(expr)
        for expr in result_exprs:
            self.visit(expr, scope)

    visitDictComp = visitDictCompListCompSetComp
    visitListComp = visitDictCompListCompSetComp
    visitSetComp = visitDictCompListCompSetComp


class PyroCodeGenerator(Python38CodeGenerator):
    """``Python38CodeGenerator`` customized for Pyro.

    Customizations visible here:
    - uses ``PyroFlowGraph``, ``AstOptimizerPyro`` and ``PyroSymbolVisitor``;
    - emits ``COMPARE_IS`` / ``COMPARE_IS_NOT`` for identity comparisons;
    - emits list/set/dict comprehensions as a ``BUILD_*`` instruction followed
      by ``compile_comprehension_body``.
    """

    flow_graph = PyroFlowGraph

    @classmethod
    def make_code_gen(
        cls,
        name: str,
        tree: AST,
        filename: str,
        flags: int,
        optimize: int,
        peephole_enabled: bool = True,
        ast_optimizer_enabled: bool = True,
    ):
        """Generate code for ``tree`` and return the code generator.

        Optionally optimizes ``tree`` first, then runs symbol analysis and
        walks the tree with a new generator instance.
        """
        if ast_optimizer_enabled:
            tree = cls.optimize_tree(optimize, tree)

        symbols = PyroSymbolVisitor()
        walk(tree, symbols)
        root_scope = symbols.scopes[tree]

        graph = cls.flow_graph(
            name, filename, root_scope, peephole_enabled=peephole_enabled
        )
        generator = cls(None, tree, symbols, graph, flags, optimize)
        walk(tree, generator)
        return generator

    @classmethod
    def optimize_tree(cls, optimize: int, tree: ast.AST):
        """Run ``AstOptimizerPyro`` over ``tree`` and return the result."""
        optimizer = AstOptimizerPyro(optimize=optimize > 0)
        return optimizer.visit(tree)

    def defaultEmitCompare(self, op):
        # Identity comparisons have dedicated opcodes; all other operators go
        # through COMPARE_OP using `self._cmp_opcode` (defined outside this
        # block).
        if isinstance(op, ast.Is):
            self.emit("COMPARE_IS")
            return
        if isinstance(op, ast.IsNot):
            self.emit("COMPARE_IS_NOT")
            return
        self.emit("COMPARE_OP", self._cmp_opcode[type(op)])

    def _emit_pyro_comprehension(self, node, build_opname, elt, val):
        """Emit ``build_opname``, then the body of comprehension ``node``."""
        self.emit(build_opname)
        self.compile_comprehension_body(node.generators, 0, elt, val, type(node))

    def visitListComp(self, node):
        self._emit_pyro_comprehension(node, "BUILD_LIST", node.elt, None)

    def visitSetComp(self, node):
        self._emit_pyro_comprehension(node, "BUILD_SET", node.elt, None)

    def visitDictComp(self, node):
        self._emit_pyro_comprehension(node, "BUILD_MAP", node.key, node.value)


def compile(source, filename, mode, flags, dont_inherit, optimize):
    """Compile ``source`` using ``PyroCodeGenerator``.

    Takes the same positional parameters as the builtin ``compile``. The
    caller's ``dont_inherit`` is not forwarded; ``None`` is passed in its
    place.
    """
    code_generator = PyroCodeGenerator
    return compiler_compile(
        source,
        filename,
        mode,
        flags,
        None,  # in place of dont_inherit
        optimize,
        code_generator,
    )
