# Portions copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
# pyre-unsafe
import ast
from typing import Any, Callable, Dict, List, Optional, Type


# Precedence levels for unparsing expressions; a larger value binds more
# tightly. Formatters pass the level of the surrounding context down to
# to_expr(), and parens() wraps an expression in parentheses when that
# context level is greater than the expression's own level.
PR_TUPLE = 0  # ',' (bare tuple displays)
PR_TEST = 1  # 'if'-'else', 'lambda'
PR_OR = 2  # 'or'
PR_AND = 3  # 'and'
PR_NOT = 4  # 'not'
PR_CMP = 5  # '<', '>', '==', '>=', '<=', '!=', 'in', 'not in', 'is', 'is not'
PR_EXPR = 6  # also the level used for the operand of '*' unpacking
PR_BOR = PR_EXPR  # '|'
PR_BXOR = 7  # '^'
PR_BAND = 8  # '&'
PR_SHIFT = 9  # '<<', '>>'
PR_ARITH = 10  # '+', '-'
PR_TERM = 11  # '*', '@', '/', '%', '//'
PR_FACTOR = 12  # unary '+', '-', '~'
PR_POWER = 13  # '**'
PR_AWAIT = 14  # 'await'
PR_ATOM = 15  # names/literals; level required for '.' and '[]' bases and 'await' operands


def get_op(node: ast.cmpop) -> str:
    """Return the space-padded source text of comparison operator *node*.

    Operators not in the table produce ``"unknown op: <ClassName>"``.
    """
    for op_type, text in (
        (ast.Is, " is "),
        (ast.IsNot, " is not "),
        (ast.In, " in "),
        (ast.NotIn, " not in "),
        (ast.Lt, " < "),
        (ast.Gt, " > "),
        (ast.LtE, " <= "),
        (ast.GtE, " >= "),
        (ast.Eq, " == "),
        (ast.NotEq, " != "),
    ):
        if isinstance(node, op_type):
            return text
    return "unknown op: " + type(node).__name__


def get_unop(node: ast.unaryop) -> str:
    """Return the source prefix for unary operator *node*.

    ``not`` keeps a trailing space so it can be glued directly to its operand.
    """
    prefixes = (
        (ast.UAdd, "+"),
        (ast.USub, "-"),
        (ast.Not, "not "),
        (ast.Invert, "~"),
    )
    return next(
        (text for op_type, text in prefixes if isinstance(node, op_type)),
        "<unknown unary op>",
    )


def _format_name(node: ast.Name, level: int) -> str:
    return node.id


def _format_compare(node: ast.Compare, level: int) -> str:
    """Render a (possibly chained) comparison such as ``a < b <= c``.

    Every operand is formatted one level above PR_CMP, so nested comparisons
    are parenthesized rather than merged into the chain.
    """
    pieces = [to_expr(node.left, PR_CMP + 1)]
    for operator, operand in zip(node.ops, node.comparators):
        pieces.append(get_op(operator))
        pieces.append(to_expr(operand, PR_CMP + 1))
    return parens(level, PR_CMP, "".join(pieces))


def _format_nameconstant(node: ast.NameConstant, level: int) -> str:
    if node.value is None:
        return "None"
    elif node.value is True:
        return "True"
    elif node.value is False:
        return "False"
    return "<unknown constant>"


def _format_num(node: ast.Num, level: int) -> str:
    return repr(node.n)


def _format_str(node: ast.Str, level: int) -> str:
    return repr(node.s)


def _format_attribute(node: ast.Attribute, level: int) -> str:
    """Render attribute access ``value.attr``.

    An integer literal followed directly by ``.`` would be read back as a
    float literal (``1.real``), so a space is inserted (``1 .real``). Only an
    exact ``int`` needs this: ``bool`` subclasses ``int`` but ``True.real`` is
    unambiguous, so it is checked with ``type(...) is int`` (matching CPython's
    own unparser, which uses an exact-int check).
    """
    value = to_expr(node.value, PR_ATOM)
    const = node.value
    needs_space = (
        isinstance(const, ast.Constant) and type(const.value) is int
    ) or type(const) is ast.Num
    return value + (" ." if needs_space else ".") + node.attr


def _format_tuple(node: ast.Tuple, level: int) -> str:
    """Render a tuple display.

    The empty tuple is always ``()``. A one-element tuple gets a trailing
    comma. Non-empty tuples are parenthesized only when the context binds
    tighter than PR_TUPLE.
    """
    elements = node.elts
    if not elements:
        return "()"
    body = ", ".join(to_expr(element) for element in elements)
    if len(elements) == 1:
        body += ","
    return parens(level, PR_TUPLE, body)


def _format_list(node: ast.List, level: int) -> str:
    """Render a list display ``[a, b, ...]``."""
    items = [to_expr(item) for item in node.elts]
    return "[%s]" % ", ".join(items)


def _format_kw(node: ast.keyword) -> str:
    """Render one call keyword: ``name=value``, or ``**value`` for unpacking.

    A keyword with no ``arg`` represents ``**mapping`` in a call.
    """
    value = to_expr(node.value)
    if node.arg:
        return f"{node.arg}={value}"
    return f"**{value}"


def _format_call(node: ast.Call, level: int) -> str:
    """Render a call ``func(args..., keywords...)``.

    The callee is formatted at PR_ATOM, so any non-atomic callee is
    parenthesized: ``(a + b)(x)`` and ``(lambda: 1)()`` round-trip instead of
    collapsing into ``a + b(x)`` / ``lambda: 1()``. A call binds tighter than
    any operator, so *level* needs no parentheses of its own.
    """
    func = to_expr(node.func, PR_ATOM)
    args = [to_expr(arg) for arg in node.args]
    args.extend(_format_kw(kw) for kw in node.keywords)
    return f"{func}({', '.join(args)})"


def _format_unaryop(node: ast.UnaryOp, level: int) -> str:
    """Render a unary operation.

    ``not`` sits at PR_NOT; ``+``, ``-`` and ``~`` sit at PR_FACTOR. The
    operand is formatted at the operator's own level.
    """
    prec = PR_NOT if isinstance(node.op, ast.Not) else PR_FACTOR
    prefix = get_unop(node.op)
    operand = to_expr(node.operand, prec)
    return parens(level, prec, prefix + operand)


# Binary operator AST class -> (source text including surrounding spaces,
# precedence level). _format_binaryop looks operators up here by exact type;
# an operator missing from this table raises KeyError there.
BIN_OPS = {
    ast.Add: (" + ", PR_ARITH),
    ast.Sub: (" - ", PR_ARITH),
    ast.Mult: (" * ", PR_TERM),
    ast.MatMult: (" @ ", PR_TERM),
    ast.Div: (" / ", PR_TERM),
    ast.Mod: (" % ", PR_TERM),
    ast.LShift: (" << ", PR_SHIFT),
    ast.RShift: (" >> ", PR_SHIFT),
    ast.BitOr: (" | ", PR_BOR),
    ast.BitXor: (" ^ ", PR_BXOR),
    ast.BitAnd: (" & ", PR_BAND),
    ast.FloorDiv: (" // ", PR_TERM),
    ast.Pow: (" ** ", PR_POWER),
}


def _format_binaryop(node: ast.BinOp, level: int) -> str:
    """Render a binary operation using the text and precedence from BIN_OPS.

    Raises:
        KeyError: if the operator type has no BIN_OPS entry.
    """
    op, tgt_level = BIN_OPS[type(node.op)]
    # Most operators are left-associative, so the right operand needs one
    # level more to force parentheses (``a - (b - c)``). ``**`` is
    # right-associative, so the extra level moves to the left operand
    # (``(a ** b) ** c``).
    rassoc = 1 if isinstance(node.op, ast.Pow) else 0
    left = to_expr(node.left, tgt_level + rassoc)
    right = to_expr(node.right, tgt_level + (1 - rassoc))
    return parens(level, tgt_level, left + op + right)


def _format_subscript(node: ast.Subscript, level: int) -> str:
    """Render ``value[slice]``.

    The slice is formatted at PR_TUPLE. On Python 3.9+ a multi-dimensional
    subscript is a bare ``ast.Tuple`` (possibly containing ``ast.Slice``), and
    it must not be parenthesized: ``a[1:2, 3]`` would otherwise become the
    invalid ``a[(1:2, 3)]``. Older ``Index``/``ExtSlice``/``Slice`` wrappers
    ignore the level, so they are unaffected.
    """
    value = to_expr(node.value, PR_ATOM)
    index = to_expr(node.slice, PR_TUPLE)
    return f"{value}[{index}]"


def _format_index(node: ast.Index, level: int) -> str:
    """Render the expression wrapped by an ``Index`` slice (tuple-level context)."""
    wrapped = node.value
    return to_expr(wrapped, PR_TUPLE)


def _format_yield(node: ast.Yield, level: int) -> str:
    """Render a ``yield`` expression, always parenthesized."""
    if not node.value:
        return "(yield)"
    return f"(yield {to_expr(node.value)})"


def _format_yield_from(node: ast.YieldFrom, level: int) -> str:
    """Render a ``yield from`` expression, always parenthesized."""
    delegate = to_expr(node.value)
    return f"(yield from {delegate})"


def _format_dict(node: ast.Dict, level: int) -> str:
    """Render a dict display, including ``**mapping`` unpacking entries.

    An unpacking entry is stored with a ``None`` key. It must be written as
    ``**value`` (operand at PR_EXPR); formatting the ``None`` key as an empty
    string would produce the invalid ``{: value}``.
    """
    items = []
    for key, value in zip(node.keys, node.values):
        if key is None:
            items.append("**" + to_expr(value, PR_EXPR))
        else:
            items.append(to_expr(key) + ": " + to_expr(value))
    return "{" + ", ".join(items) + "}"


def _format_comprehension(node: ast.comprehension) -> str:
    """Render one ``for``/``async for`` clause of a comprehension, with its ``if`` filters.

    The result starts with a leading space so clauses can be concatenated
    directly after the element expression.
    """
    keyword = " async for " if node.is_async else " for "
    parts = [
        keyword,
        to_expr(node.target, PR_TUPLE),
        " in ",
        to_expr(node.iter, PR_TEST + 1),
    ]
    parts.extend(" if " + to_expr(cond, PR_TEST + 1) for cond in node.ifs)
    return "".join(parts)


def parens(level: int, target_lvl: int, value: str) -> str:
    """Parenthesize *value* when its context *level* binds tighter than *target_lvl*."""
    return f"({value})" if level > target_lvl else value


def _format_await(node: ast.Await, level: int) -> str:
    """Render ``await operand``; the operand must be an atom."""
    operand = to_expr(node.value, PR_ATOM)
    return parens(level, PR_AWAIT, f"await {operand}")


def _format_starred(node: ast.Starred, level: int) -> str:
    """Render ``*operand`` unpacking, with the operand at PR_EXPR."""
    operand = to_expr(node.value, PR_EXPR)
    return f"*{operand}"


def _format_boolop(node: ast.BoolOp, level: int) -> str:
    """Render an ``and``/``or`` chain.

    Each operand is formatted one level above the operator, so nested boolean
    operations of the same kind are parenthesized rather than flattened.
    """
    if isinstance(node.op, ast.And):
        joiner, prec = " and ", PR_AND
    else:
        joiner, prec = " or ", PR_OR
    operands = [to_expr(value, prec + 1) for value in node.values]
    return parens(level, prec, joiner.join(operands))


def _format_arguments(node: ast.arguments) -> str:
    res = []
    for i, arg in enumerate(node.args):
        if i:
            res.append(", ")
        res.append(arg.arg)
        if i < len(node.defaults):
            res.append("=")
            res.append(to_expr(node.defaults[i]))

    if node.vararg or node.kwonlyargs:
        if node.args:
            res.append(", ")
        res.append("*")
        vararg = node.vararg
        if vararg:
            res.append(vararg.arg)

    for i, arg in enumerate(node.kwonlyargs):
        if res:
            res.append(", ")
        res.append(arg.arg)
        if i < len(node.kw_defaults) and node.kw_defaults[i]:
            res.append("=")
            res.append(to_expr(node.kw_defaults[i]))

    return "".join(res)


def _format_lambda(node: ast.Lambda, level: int) -> str:
    value = "lambda "
    if not node.args.args:
        value = "lambda"
    value += _format_arguments(node.args)
    value += ": " + to_expr(node.body, PR_TEST)

    return parens(level, PR_TEST, value)


def _format_if_exp(node: ast.IfExp, level: int) -> str:
    """Render ``body if test else orelse``.

    Only the ``else`` branch may itself be an unparenthesized conditional
    (the form is right-associative). The body and test are formatted one
    level higher.
    """
    test = to_expr(node.test, PR_TEST + 1)
    body = to_expr(node.body, PR_TEST + 1)
    orelse = to_expr(node.orelse, PR_TEST)
    return parens(level, PR_TEST, body + " if " + test + " else " + orelse)


def _format_set(node: ast.Set, level: int) -> str:
    """Render a set display ``{a, b, ...}``."""
    members = [to_expr(member, PR_TEST) for member in node.elts]
    return "{%s}" % ", ".join(members)


def _format_comprehensions(nodes: List[ast.comprehension]) -> str:
    """Concatenate the rendered ``for`` clauses of a comprehension in order."""
    return "".join(map(_format_comprehension, nodes))


def _format_set_comp(node: ast.SetComp, level: int) -> str:
    """Render a set comprehension ``{elt for ...}``."""
    element = to_expr(node.elt)
    clauses = _format_comprehensions(node.generators)
    return "{%s%s}" % (element, clauses)


def _format_list_comp(node: ast.ListComp, level: int) -> str:
    """Render a list comprehension ``[elt for ...]``."""
    element = to_expr(node.elt)
    clauses = _format_comprehensions(node.generators)
    return "[%s%s]" % (element, clauses)


def _format_dict_comp(node: ast.DictComp, level: int) -> str:
    """Render a dict comprehension ``{key: value for ...}``."""
    key = to_expr(node.key)
    value = to_expr(node.value)
    clauses = _format_comprehensions(node.generators)
    return "{%s: %s%s}" % (key, value, clauses)


def _format_gen_exp(node: ast.GeneratorExp, level: int) -> str:
    """Render a generator expression, always with its own parentheses."""
    element = to_expr(node.elt)
    clauses = _format_comprehensions(node.generators)
    return "(%s%s)" % (element, clauses)


def _escape_fstring_literal(text: str) -> str:
    """Double ``{`` and ``}`` so literal text reads back unchanged inside an f-string."""
    return text.replace("{", "{{").replace("}", "}}")


def format_fstring_elt(res: List[str], elt: ast.expr, is_format_spec: bool) -> None:
    """Append the f-string source text for one ``JoinedStr`` element to *res*.

    Literal segments have their braces doubled. The parser turns ``{{`` into a
    single ``{`` in the AST, so emitting it unescaped would read back as a
    replacement field: ``f"{{x}}"`` would otherwise become ``f'{x}'``.

    Args:
        res: Output buffer; pieces are appended in order.
        elt: A literal segment, a nested ``JoinedStr`` (format spec), or a
            ``FormattedValue`` replacement field. Other node types are ignored.
        is_format_spec: Forwarded to nested ``JoinedStr`` formatting.
    """
    if isinstance(elt, ast.Constant):
        # Python 3.8+ parsers emit every literal segment as a str Constant.
        res.append(_escape_fstring_literal(elt.value))
    elif isinstance(elt, ast.JoinedStr):
        res.append(format_joinedstr(elt, PR_TEST, is_format_spec))
    elif isinstance(elt, ast.FormattedValue):
        expr = to_expr(elt.value, PR_TEST + 1)
        # "{{" would read back as an escaped brace, so separate an expression
        # that itself starts with "{" (a dict or set display, for example).
        res.append("{ " if expr.startswith("{") else "{")
        res.append(expr)
        conversion = elt.conversion
        if conversion is not None and conversion != -1:
            # conversion is stored as a character code; chr() recovers the letter.
            res.append("!")
            res.append(chr(conversion))
        format_spec = elt.format_spec
        if format_spec is not None:
            res.append(":")
            format_fstring_elt(res, format_spec, True)
        res.append("}")
    elif isinstance(elt, ast.Str):
        # Legacy literal node produced by parsers older than Python 3.8.
        res.append(_escape_fstring_literal(elt.s))


def format_joinedstr(node: ast.JoinedStr, level: int, is_format_spec=False) -> str:
    """Render an f-string.

    At top level the result is ``f`` followed by the repr of the assembled
    body. When rendering a nested format spec (*is_format_spec* true), the raw
    body is returned without the ``f`` prefix or quotes.
    """
    pieces: List[str] = []
    for value in node.values:
        format_fstring_elt(pieces, value, is_format_spec)
    body = "".join(pieces)
    return body if is_format_spec else "f" + repr(body)


def _format_slice(node: ast.Slice, level: int):
    res = ""
    if node.lower is not None:
        res += to_expr(node.lower)
    res += ":"
    if node.upper is not None:
        res += to_expr(node.upper)
    if node.step:
        res += ":"
        res += to_expr(node.step)
    return res


def _format_extslice(node: ast.ExtSlice, level: int):
    """Render a legacy multi-dimensional slice as comma-separated dimensions."""
    return ", ".join(map(to_expr, node.dims))


def _format_constant(node: ast.Constant, level: int):
    if node.value is Ellipsis:
        return "..."

    return repr(node.value)


# Dispatch table used by to_expr(): exact AST node class -> formatter, called
# as formatter(node, level) and returning source text. Lookup is by
# type(node), so subclasses of these classes are not matched.
# NOTE(review): Num/Str/Bytes/Ellipsis/NameConstant/Index/ExtSlice appear to be
# kept for ASTs from older Python versions; newer CPython releases deprecate
# (and eventually remove) these classes — confirm the supported versions.
_FORMATTERS: Dict[Type, Callable[[Any, int], str]] = {
    ast.BoolOp: _format_boolop,
    ast.BinOp: _format_binaryop,
    ast.UnaryOp: _format_unaryop,
    ast.Lambda: _format_lambda,
    ast.IfExp: _format_if_exp,
    ast.Dict: _format_dict,
    ast.Set: _format_set,
    ast.GeneratorExp: _format_gen_exp,
    ast.ListComp: _format_list_comp,
    ast.SetComp: _format_set_comp,
    ast.DictComp: _format_dict_comp,
    ast.Yield: _format_yield,
    ast.YieldFrom: _format_yield_from,
    ast.Await: _format_await,
    ast.Compare: _format_compare,
    ast.Call: _format_call,
    ast.Constant: _format_constant,
    ast.Num: _format_num,
    ast.Str: _format_str,
    ast.JoinedStr: format_joinedstr,
    # No standalone formatter: to_expr() treats a None entry as unsupported.
    # FormattedValue nodes are rendered by format_fstring_elt() as part of
    # their enclosing JoinedStr.
    ast.FormattedValue: None,
    ast.Bytes: lambda node, level: repr(node.s),
    ast.Ellipsis: lambda node, level: "...",
    ast.NameConstant: _format_nameconstant,
    ast.Attribute: _format_attribute,
    ast.Subscript: _format_subscript,
    ast.Starred: _format_starred,
    ast.Name: _format_name,
    ast.List: _format_list,
    ast.Tuple: _format_tuple,
    ast.Slice: _format_slice,
    ast.ExtSlice: _format_extslice,
    ast.Index: _format_index,
}


def to_expr(node: Optional[ast.AST], level=PR_TEST) -> str:
    """Render *node* as Python source for a context of precedence *level*.

    ``None`` renders as the empty string. A node whose exact type has no
    formatter in ``_FORMATTERS`` renders as ``<unsupported node: ClassName>``.
    """
    if node is None:
        return ""
    node_type = type(node)
    formatter = _FORMATTERS.get(node_type)
    if formatter is None:
        return f"<unsupported node: {node_type.__name__}>"
    return formatter(node, level)
