library/compiler/unparse.py (335 lines of code) (raw):
# Portions copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
# pyre-unsafe
import ast
from typing import Any, Callable, Dict, List, Optional, Type
# Precedence levels for expression kinds, ordered from loosest (PR_TUPLE) to
# tightest (PR_ATOM) binding.  Formatters pass the level required by the
# surrounding context; `parens` wraps a sub-expression whenever that required
# level is greater than the sub-expression's own level.
PR_TUPLE = 0
PR_TEST = 1  # 'if'-'else', 'lambda'
PR_OR = 2  # 'or'
PR_AND = 3  # 'and'
PR_NOT = 4  # 'not'
PR_CMP = 5  # '<', '>', '==', '>=', '<=', '!=', 'in', 'not in', 'is', 'is not'
PR_EXPR = 6
PR_BOR = PR_EXPR  # '|' (shares the PR_EXPR level)
PR_BXOR = 7  # '^'
PR_BAND = 8  # '&'
PR_SHIFT = 9  # '<<', '>>'
PR_ARITH = 10  # '+', '-'
PR_TERM = 11  # '*', '@', '/', '%', '//'
PR_FACTOR = 12  # unary '+', '-', '~'
PR_POWER = 13  # '**'
PR_AWAIT = 14  # 'await'
PR_ATOM = 15
def get_op(node: ast.cmpop) -> str:
    """Return the space-padded source spelling of a comparison operator.

    Operator nodes that are not recognized produce a diagnostic string that
    names the node's type instead of raising.
    """
    spellings = (
        (ast.Is, " is "),
        (ast.IsNot, " is not "),
        (ast.In, " in "),
        (ast.NotIn, " not in "),
        (ast.Lt, " < "),
        (ast.Gt, " > "),
        (ast.LtE, " <= "),
        (ast.GtE, " >= "),
        (ast.Eq, " == "),
        (ast.NotEq, " != "),
    )
    for op_type, text in spellings:
        if isinstance(node, op_type):
            return text
    return "unknown op: " + type(node).__name__
def get_unop(node: ast.unaryop) -> str:
    """Return the source prefix for a unary operator node.

    `not` carries a trailing space; the symbolic operators do not.  Any other
    node yields the placeholder "<unknown unary op>".
    """
    prefixes = (
        (ast.UAdd, "+"),
        (ast.USub, "-"),
        (ast.Not, "not "),
        (ast.Invert, "~"),
    )
    for op_type, text in prefixes:
        if isinstance(node, op_type):
            return text
    return "<unknown unary op>"
def _format_name(node: ast.Name, level: int) -> str:
return node.id
def _format_compare(node: ast.Compare, level: int) -> str:
    """Render a (possibly chained) comparison such as ``a < b <= c``.

    Every operand is rendered one level above PR_CMP so that a nested
    comparison used as an operand is parenthesized.
    """
    operand_level = PR_CMP + 1
    pieces = [to_expr(node.left, operand_level)]
    for operand, op in zip(node.comparators, node.ops):
        pieces.append(get_op(op))
        pieces.append(to_expr(operand, operand_level))
    return parens(level, PR_CMP, "".join(pieces))
def _format_nameconstant(node: ast.NameConstant, level: int) -> str:
if node.value is None:
return "None"
elif node.value is True:
return "True"
elif node.value is False:
return "False"
return "<unknown constant>"
def _format_num(node: ast.Num, level: int) -> str:
return repr(node.n)
def _format_str(node: ast.Str, level: int) -> str:
return repr(node.s)
def _format_attribute(node: ast.Attribute, level: int) -> str:
    """Render an attribute access ``value.attr``.

    When the object is an integer constant (or a Num node) a space is put
    before the dot, giving e.g. ``1 .real``, so the dot is not read as part
    of the numeric literal.
    """
    target = node.value
    rendered = to_expr(target, PR_ATOM)
    if isinstance(target, ast.Constant) and isinstance(target.value, int):
        numeric_target = True
    else:
        numeric_target = type(target) is ast.Num
    separator = " ." if numeric_target else "."
    return rendered + separator + node.attr
def _format_tuple(node: ast.Tuple, level: int) -> str:
if not node.elts:
return "()"
elif len(node.elts) == 1:
return parens(level, PR_TUPLE, to_expr(node.elts[0]) + ",")
return parens(level, PR_TUPLE, ", ".join(to_expr(elm) for elm in node.elts))
def _format_list(node: ast.List, level: int) -> str:
return "[" + ", ".join(to_expr(elm) for elm in node.elts) + "]"
def _format_kw(node: ast.keyword) -> str:
    """Render one keyword argument of a call.

    Returns ``name=value`` when the keyword has a name, and ``**value`` when
    it does not (the node represents ``**mapping`` unpacking).
    """
    value = to_expr(node.value)
    if node.arg:
        return f"{node.arg}={value}"
    return f"**{value}"
def _format_call(node: ast.Call, level: int) -> str:
    """Render a call expression ``func(arg, ..., kw=value, ...)``.

    Positional arguments come first, followed by keyword arguments in the
    order they appear on the node.
    """
    args = [to_expr(arg) for arg in node.args]
    args.extend(_format_kw(keyword) for keyword in node.keywords)
    # The callee must be rendered as an atom: anything looser (a binary
    # operation, a lambda, a conditional, ...) needs its own parentheses,
    # otherwise `(a + b)(c)` would come out as `a + b(c)`.
    callee = to_expr(node.func, PR_ATOM)
    return callee + "(" + ", ".join(args) + ")"
def _format_unaryop(node: ast.UnaryOp, level: int) -> str:
    """Render a unary operation.

    ``not`` uses the PR_NOT level; ``+``, ``-`` and ``~`` use PR_FACTOR.  The
    operand is rendered at the same level as the operator itself.
    """
    operator = node.op
    own_level = PR_NOT if isinstance(operator, ast.Not) else PR_FACTOR
    text = get_unop(operator) + to_expr(node.operand, own_level)
    return parens(level, own_level, text)
# Maps each binary-operator node type to a pair of
# (operator text padded with spaces, precedence level of the operator).
# Consulted by `_format_binaryop`, which indexes it with `type(node.op)`.
BIN_OPS = {
    ast.Add: (" + ", PR_ARITH),
    ast.Sub: (" - ", PR_ARITH),
    ast.Mult: (" * ", PR_TERM),
    ast.MatMult: (" @ ", PR_TERM),
    ast.Div: (" / ", PR_TERM),
    ast.Mod: (" % ", PR_TERM),
    ast.LShift: (" << ", PR_SHIFT),
    ast.RShift: (" >> ", PR_SHIFT),
    ast.BitOr: (" | ", PR_BOR),
    ast.BitXor: (" ^ ", PR_BXOR),
    ast.BitAnd: (" & ", PR_BAND),
    ast.FloorDiv: (" // ", PR_TERM),
    ast.Pow: (" ** ", PR_POWER),
}
def _format_binaryop(node: ast.BinOp, level: int) -> str:
    """Render a binary operation, parenthesizing operands where needed.

    The operator text and its precedence come from BIN_OPS; an operator type
    missing from that table raises KeyError.

    ``**`` groups right-to-left, so its left operand is rendered one level
    tighter and its right operand at the operator's own level.  Every other
    operator groups left-to-right, so the levels are the other way around.
    """
    op, tgt_level = BIN_OPS[type(node.op)]
    rassoc = 1 if isinstance(node.op, ast.Pow) else 0
    left = to_expr(node.left, tgt_level + rassoc)
    right = to_expr(node.right, tgt_level + (1 - rassoc))
    return parens(level, tgt_level, left + op + right)
def _format_subscript(node: ast.Subscript, level: int) -> str:
    """Render a subscription ``value[slice]``.

    The subscripted object is rendered as an atom.  The slice is rendered at
    PR_TUPLE so that a tuple subscript is not wrapped in parentheses: this
    matches `_format_index`, and it keeps subscripts that mix slices and
    indices (``a[1:2, 3]``) valid, because ``a[(1:2, 3)]`` is a syntax error.
    """
    return f"{to_expr(node.value, PR_ATOM)}[{to_expr(node.slice, PR_TUPLE)}]"
def _format_index(node: ast.Index, level: int) -> str:
    """Render the expression wrapped by an Index node.

    The wrapped value is rendered at PR_TUPLE, so a tuple index is emitted
    without surrounding parentheses.
    """
    wrapped = node.value
    return to_expr(wrapped, PR_TUPLE)
def _format_yield(node: ast.Yield, level: int) -> str:
if node.value:
return "(yield " + to_expr(node.value) + ")"
return "(yield)"
def _format_yield_from(node: ast.YieldFrom, level: int) -> str:
    """Render a ``yield from`` expression, always wrapped in parentheses."""
    source = to_expr(node.value)
    return f"(yield from {source})"
def _format_dict(node: ast.Dict, level: int) -> str:
return (
"{"
+ ", ".join(
to_expr(k) + ": " + to_expr(v) for k, v in zip(node.keys, node.values)
)
+ "}"
)
def _format_comprehension(node: ast.comprehension) -> str:
    """Render one ``for ... in ... [if ...]`` clause of a comprehension.

    The result starts with a space so clauses can be appended directly after
    the element expression or a previous clause.
    """
    keyword = " async for " if node.is_async else " for "
    clauses = [
        keyword,
        to_expr(node.target, PR_TUPLE),
        " in ",
        to_expr(node.iter, PR_TEST + 1),
    ]
    for condition in node.ifs:
        clauses.append(" if " + to_expr(condition, PR_TEST + 1))
    return "".join(clauses)
def parens(level: int, target_lvl: int, value: str) -> str:
    """Wrap *value* in parentheses when the context demands tighter binding.

    *level* is the precedence the surrounding context requires and
    *target_lvl* is the precedence of the expression that produced *value*;
    parentheses are added only when *level* is strictly greater.
    """
    needs_wrap = level > target_lvl
    return "({})".format(value) if needs_wrap else value
def _format_await(node: ast.Await, level: int):
    """Render an ``await`` expression; the awaited operand must be an atom."""
    awaited = to_expr(node.value, PR_ATOM)
    return parens(level, PR_AWAIT, f"await {awaited}")
def _format_starred(node: ast.Starred, level: int):
    """Render a starred expression ``*value`` with the operand at PR_EXPR."""
    operand = to_expr(node.value, PR_EXPR)
    return f"*{operand}"
def _format_boolop(node: ast.BoolOp, level: int) -> str:
    """Render an ``and`` / ``or`` chain.

    Each operand is rendered one level above the operator's own level, so a
    nested boolean operation of the same kind is parenthesized.
    """
    if isinstance(node.op, ast.And):
        joiner, own_level = " and ", PR_AND
    else:
        joiner, own_level = " or ", PR_OR
    operands = [to_expr(value, own_level + 1) for value in node.values]
    return parens(level, own_level, joiner.join(operands))
def _format_arguments(node: ast.arguments) -> str:
res = []
for i, arg in enumerate(node.args):
if i:
res.append(", ")
res.append(arg.arg)
if i < len(node.defaults):
res.append("=")
res.append(to_expr(node.defaults[i]))
if node.vararg or node.kwonlyargs:
if node.args:
res.append(", ")
res.append("*")
vararg = node.vararg
if vararg:
res.append(vararg.arg)
for i, arg in enumerate(node.kwonlyargs):
if res:
res.append(", ")
res.append(arg.arg)
if i < len(node.kw_defaults) and node.kw_defaults[i]:
res.append("=")
res.append(to_expr(node.kw_defaults[i]))
return "".join(res)
def _format_lambda(node: ast.Lambda, level: int) -> str:
value = "lambda "
if not node.args.args:
value = "lambda"
value += _format_arguments(node.args)
value += ": " + to_expr(node.body, PR_TEST)
return parens(level, PR_TEST, value)
def _format_if_exp(node: ast.IfExp, level: int) -> str:
    """Render a conditional expression ``body if test else orelse``.

    The body and the test are rendered one level above PR_TEST; the else
    branch is rendered at PR_TEST, so a chained conditional there is not
    parenthesized.
    """
    then_text = to_expr(node.body, PR_TEST + 1)
    else_text = to_expr(node.orelse, PR_TEST)
    cond_text = to_expr(node.test, PR_TEST + 1)
    combined = then_text + " if " + cond_text + " else " + else_text
    return parens(level, PR_TEST, combined)
def _format_set(node: ast.Set, level: int) -> str:
return "{" + ", ".join(to_expr(elt, PR_TEST) for elt in node.elts) + "}"
def _format_comprehensions(nodes: List[ast.comprehension]) -> str:
return "".join(_format_comprehension(n) for n in nodes)
def _format_set_comp(node: ast.SetComp, level: int) -> str:
    """Render a set comprehension ``{elt for ...}``."""
    element = to_expr(node.elt)
    clauses = _format_comprehensions(node.generators)
    return "".join(("{", element, clauses, "}"))
def _format_list_comp(node: ast.ListComp, level: int) -> str:
    """Render a list comprehension ``[elt for ...]``."""
    element = to_expr(node.elt)
    clauses = _format_comprehensions(node.generators)
    return "".join(("[", element, clauses, "]"))
def _format_dict_comp(node: ast.DictComp, level: int) -> str:
    """Render a dict comprehension ``{key: value for ...}``."""
    key_text = to_expr(node.key)
    value_text = to_expr(node.value)
    clauses = _format_comprehensions(node.generators)
    return "".join(("{", key_text, ": ", value_text, clauses, "}"))
def _format_gen_exp(node: ast.GeneratorExp, level: int) -> str:
    """Render a generator expression, always wrapped in parentheses."""
    element = to_expr(node.elt)
    clauses = _format_comprehensions(node.generators)
    return "".join(("(", element, clauses, ")"))
def _escape_fstring_literal(text: str) -> str:
    """Double the braces in literal f-string text so they stay literal."""
    return text.replace("{", "{{").replace("}", "}}")
def format_fstring_elt(res: List[str], elt: ast.expr, is_format_spec: bool):
    """Append the source text for one element of an f-string to *res*.

    Handles literal text (Constant / Str nodes), nested JoinedStr nodes and
    FormattedValue replacement fields, including the ``!c`` conversion and
    the ``:spec`` format specifier.  Other node types append nothing.

    *is_format_spec* is forwarded when rendering a nested JoinedStr so that a
    format specifier is emitted as raw text rather than as an ``f'...'``
    literal.
    """
    if isinstance(elt, ast.Constant):
        # Literal braces must be written back as `{{` / `}}`, otherwise the
        # output would treat them as the start or end of a replacement field.
        res.append(_escape_fstring_literal(elt.value))
    elif isinstance(elt, ast.Str):
        res.append(_escape_fstring_literal(elt.s))
    elif isinstance(elt, ast.JoinedStr):
        res.append(format_joinedstr(elt, PR_TEST, is_format_spec))
    elif isinstance(elt, ast.FormattedValue):
        expr = to_expr(elt.value, PR_TEST + 1)
        if expr.startswith("{"):
            # Expression starts with a brace, we need an extra space so the
            # two braces are not read as an escaped literal `{{`.
            res.append("{ ")
        else:
            res.append("{")
        res.append(expr)
        conversion = elt.conversion
        if conversion is not None and conversion != -1:
            # The conversion is stored as a character code; -1 means "none".
            res.append("!")
            res.append(chr(conversion))
        format_spec = elt.format_spec
        if format_spec is not None:
            res.append(":")
            format_fstring_elt(res, format_spec, True)
        res.append("}")
def format_joinedstr(node: ast.JoinedStr, level: int, is_format_spec=False) -> str:
    """Render an f-string node.

    Returns the repr of the assembled body prefixed with ``f``.  When
    *is_format_spec* is true the node is a format specifier nested inside a
    replacement field, and the assembled body is returned as-is.
    """
    pieces = []
    for part in node.values:
        format_fstring_elt(pieces, part, is_format_spec)
    body = "".join(pieces)
    return body if is_format_spec else "f" + repr(body)
def _format_slice(node: ast.Slice, level: int):
res = ""
if node.lower is not None:
res += to_expr(node.lower)
res += ":"
if node.upper is not None:
res += to_expr(node.upper)
if node.step:
res += ":"
res += to_expr(node.step)
return res
def _format_extslice(node: ast.ExtSlice, level: int):
return ", ".join(to_expr(d) for d in node.dims)
def _format_constant(node: ast.Constant, level: int):
if node.value is Ellipsis:
return "..."
return repr(node.value)
# Dispatch table used by `to_expr`: maps the exact type of an AST node to the
# function that renders it.  Each formatter is called as
# `formatter(node, level)`, where `level` is the precedence required by the
# surrounding context.
_FORMATTERS: Dict[Type, Callable[[Any, int], str]] = {
    ast.BoolOp: _format_boolop,
    ast.BinOp: _format_binaryop,
    ast.UnaryOp: _format_unaryop,
    ast.Lambda: _format_lambda,
    ast.IfExp: _format_if_exp,
    ast.Dict: _format_dict,
    ast.Set: _format_set,
    ast.GeneratorExp: _format_gen_exp,
    ast.ListComp: _format_list_comp,
    ast.SetComp: _format_set_comp,
    ast.DictComp: _format_dict_comp,
    ast.Yield: _format_yield,
    ast.YieldFrom: _format_yield_from,
    ast.Await: _format_await,
    ast.Compare: _format_compare,
    ast.Call: _format_call,
    ast.Constant: _format_constant,
    ast.Num: _format_num,
    ast.Str: _format_str,
    ast.JoinedStr: format_joinedstr,
    # No standalone formatter: `to_expr` treats a None entry like a missing
    # one and reports the node as unsupported.  FormattedValue nodes are
    # rendered by `format_fstring_elt` as part of their enclosing f-string.
    ast.FormattedValue: None,
    ast.Bytes: lambda node, level: repr(node.s),
    ast.Ellipsis: lambda node, level: "...",
    ast.NameConstant: _format_nameconstant,
    ast.Attribute: _format_attribute,
    ast.Subscript: _format_subscript,
    ast.Starred: _format_starred,
    ast.Name: _format_name,
    ast.List: _format_list,
    ast.Tuple: _format_tuple,
    ast.Slice: _format_slice,
    ast.ExtSlice: _format_extslice,
    ast.Index: _format_index,
}
def to_expr(node: Optional[ast.AST], level=PR_TEST) -> str:
    """Render an expression node as source text.

    *level* is the precedence required by the surrounding context and is
    passed on to the node's formatter.  A None node renders as the empty
    string; a node type with no formatter renders as a
    ``<unsupported node: TypeName>`` placeholder.
    """
    if node is None:
        return ""
    node_type = type(node)
    formatter = _FORMATTERS.get(node_type)
    if formatter is None:
        return "<unsupported node: " + node_type.__name__ + ">"
    return formatter(node, level)