library/_compiler.py (251 lines of code) (raw):
# Copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
import ast
from ast import AST
from compiler import compile as compiler_compile
from compiler.optimizer import BIN_OPS, is_const, get_const_value
from compiler.py38.optimizer import AstOptimizer38
from compiler.pyassem import PyFlowGraph38
from compiler.pycodegen import Python38CodeGenerator
from compiler.symbols import SymbolVisitor
from compiler.visitor import ASTVisitor, walk
import _compiler_opcode as opcodepyro
def should_rewrite_printf(node):
return isinstance(node.left, ast.Str) and isinstance(node.op, ast.Mod)
def create_conversion_call(name, value):
method = ast.Attribute(ast.Str(""), name, ast.Load())
return ast.Call(method, args=[value], keywords=[])
def try_constant_fold_mod(format_string, right):
r = get_const_value(right)
return ast.Str(format_string.__mod__(r))
class AstOptimizerPyro(AstOptimizer38):
def rewrite_str_mod(self, left, right): # noqa: C901
format_string = left.s
try:
if is_const(right):
return try_constant_fold_mod(format_string, right)
# Try and collapse the whole expression into a string
const_tuple = self.makeConstTuple(right.elts)
if const_tuple:
return ast.Str(format_string.__mod__(const_tuple.value))
except Exception:
pass
n_specifiers = 0
i = 0
length = len(format_string)
while i < length:
i = format_string.find("%", i)
if i == -1:
break
ch = format_string[i]
i += 1
if i >= length:
# Invalid format string ending in a single percent
return None
ch = format_string[i]
i += 1
if ch == "%":
# Break the string apart at '%'
continue
elif ch == "(":
# We don't support dict lookups and may get confused from
# inner '%' chars
return None
n_specifiers += 1
rhs = right
if isinstance(right, ast.Tuple):
rhs_values = rhs.elts
num_values = len(rhs_values)
else:
# If RHS is not a tuple constructor, then we only support the
# situation with a single format specifier in the string, by
# normalizing `rhs` to a one-element tuple:
# `_mod_check_single_arg(rhs)[0]`
rhs_values = None
if n_specifiers != 1:
return None
num_values = 1
i = 0
value_idx = 0
segment_begin = 0
strings = []
while i < length:
i = format_string.find("%", i)
if i == -1:
break
ch = format_string[i]
i += 1
segment_end = i - 1
if segment_end - segment_begin > 0:
substr = format_string[segment_begin:segment_end]
strings.append(ast.Str(substr))
if i >= length:
return None
ch = format_string[i]
i += 1
# Parse flags and width
spec_begin = i - 1
have_width = False
while True:
if ch == "0":
# TODO(matthiasb): Support ' ', '+', '#', etc
# They mostly have the same meaning. However they can
# appear in any order here but must follow stricter
# conventions in f-strings.
if i >= length:
return None
ch = format_string[i]
i += 1
continue
break
if "1" <= ch <= "9":
have_width = True
if i >= length:
return None
ch = format_string[i]
i += 1
while "0" <= ch <= "9":
if i >= length:
return None
ch = format_string[i]
i += 1
spec_str = ""
if i - 1 - spec_begin > 0:
spec_str = format_string[spec_begin : i - 1]
if ch == "%":
# Handle '%%'
segment_begin = i - 1
continue
# Handle remaining supported cases that use a value from RHS
if rhs_values is not None:
if value_idx >= num_values:
return None
value = rhs_values[value_idx]
else:
# We have a situation like `"%s" % x` without tuple on RHS.
# Transform to: f"{''._mod_check_single_arg(x)[0]}"
converted = create_conversion_call("_mod_check_single_arg", rhs)
value = ast.Subscript(converted, ast.Index(ast.Num(0)), ast.Load())
value_idx += 1
if ch in "sra":
# Rewrite "%s" % (x,) to f"{x!s}"
if have_width:
# Need to explicitly specify alignment because `%5s`
# aligns right, while `f"{x:5}"` aligns left.
spec_str = ">" + spec_str
format_spec = ast.Str(spec_str) if spec_str else None
formatted = ast.FormattedValue(value, ord(ch), format_spec)
strings.append(formatted)
elif ch in "diu":
# Rewrite "%d" % (x,) to f"{''._mod_convert_number_int(x)}".
# Calling a method on the empty string is a hack to access a
# well-known function regardless of the surrounding
# environment.
converted = create_conversion_call("_mod_convert_number_int", value)
format_spec = ast.Str(spec_str) if spec_str else None
formatted = ast.FormattedValue(converted, -1, format_spec)
strings.append(formatted)
elif ch in "xXo":
# Rewrite "%x" % (v,) to f"{''._mod_convert_number_index(v):x}".
# Calling a method on the empty string is a hack to access a
# well-known function regardless of the surrounding
# environment.
converted = create_conversion_call("_mod_convert_number_index", value)
format_spec = ast.Str(spec_str + ch)
formatted = ast.FormattedValue(converted, -1, format_spec)
strings.append(formatted)
else:
return None
# Begin next segment after specifier
segment_begin = i
if value_idx != num_values:
return None
segment_end = length
if segment_end - segment_begin > 0:
substr = format_string[segment_begin:segment_end]
strings.append(ast.Str(substr))
return ast.JoinedStr(strings)
def visitBinOp(self, node: ast.BinOp) -> ast.expr:
left = self.visit(node.left)
right = self.visit(node.right)
if is_const(left) and is_const(right):
handler = BIN_OPS.get(type(node.op))
if handler is not None:
lval = get_const_value(left)
rval = get_const_value(right)
try:
return ast.copy_location(ast.Constant(handler(lval, rval)), node)
except Exception:
pass
if should_rewrite_printf(node):
result = self.rewrite_str_mod(left, right)
if result:
return self.visit(result)
return self.update_node(node, left=left, right=right)
class PyroFlowGraph(PyFlowGraph38):
opcode = opcodepyro.opcode
class ComprehensionRenamer(ASTVisitor):
def __init__(self, scope):
super().__init__()
# We need a prefix that is unique per-scope for each renaming round.
index = getattr(scope, "last_comprehension_rename_index", -1) + 1
scope.last_comprehension_rename_index = index
self.prefix = f"_gen{str(index) if index > 0 else ''}$"
self.new_names = {}
self.is_target = False
def visitName(self, node):
if self.is_target and isinstance(node.ctx, (ast.Store, ast.Del)):
name = node.id
new_name = self.prefix + name
self.new_names[name] = new_name
node.id = new_name
else:
new_name = self.new_names.get(node.id)
if new_name is not None:
node.id = new_name
def visitarg(self, node):
new_name = self.new_names.get(node.arg)
if new_name is not None:
node.arg = new_name
class PyroSymbolVisitor(SymbolVisitor):
def visitDictCompListCompSetComp(self, node, scope):
# Check for unexpected assignments.
scope.comp_iter_expr += 1
self.visit(node.generators[0].iter, scope)
scope.comp_iter_expr -= 1
renamer = ComprehensionRenamer(scope)
is_outer = True
for gen in node.generators:
renamer.visit(gen.iter)
renamer.is_target = True
renamer.visit(gen.target)
renamer.is_target = False
for if_node in gen.ifs:
renamer.visit(if_node)
self.visitcomprehension(gen, scope, is_outer)
is_outer = False
if isinstance(node, ast.DictComp):
renamer.visit(node.value)
renamer.visit(node.key)
self.visit(node.value, scope)
self.visit(node.key, scope)
else:
renamer.visit(node.elt)
self.visit(node.elt, scope)
visitDictComp = visitDictCompListCompSetComp
visitListComp = visitDictCompListCompSetComp
visitSetComp = visitDictCompListCompSetComp
class PyroCodeGenerator(Python38CodeGenerator):
flow_graph = PyroFlowGraph
@classmethod
def make_code_gen(
cls,
name: str,
tree: AST,
filename: str,
flags: int,
optimize: int,
peephole_enabled: bool = True,
ast_optimizer_enabled: bool = True,
):
if ast_optimizer_enabled:
tree = cls.optimize_tree(optimize, tree)
s = PyroSymbolVisitor()
walk(tree, s)
graph = cls.flow_graph(
name, filename, s.scopes[tree], peephole_enabled=peephole_enabled
)
code_gen = cls(None, tree, s, graph, flags, optimize)
walk(tree, code_gen)
return code_gen
@classmethod
def optimize_tree(cls, optimize: int, tree: ast.AST):
return AstOptimizerPyro(optimize=optimize > 0).visit(tree)
def defaultEmitCompare(self, op):
if isinstance(op, ast.Is):
self.emit("COMPARE_IS")
elif isinstance(op, ast.IsNot):
self.emit("COMPARE_IS_NOT")
else:
self.emit("COMPARE_OP", self._cmp_opcode[type(op)])
def visitListComp(self, node):
self.emit("BUILD_LIST")
self.compile_comprehension_body(node.generators, 0, node.elt, None, type(node))
def visitSetComp(self, node):
self.emit("BUILD_SET")
self.compile_comprehension_body(node.generators, 0, node.elt, None, type(node))
def visitDictComp(self, node):
self.emit("BUILD_MAP")
self.compile_comprehension_body(
node.generators, 0, node.key, node.value, type(node)
)
def compile(source, filename, mode, flags, dont_inherit, optimize):
return compiler_compile(
source, filename, mode, flags, None, optimize, PyroCodeGenerator
)