library/compiler/optimizer.py (162 lines of code) (raw):

# Portions copyright (c) Facebook, Inc. and its affiliates. (http://www.facebook.com)
# pyre-unsafe
"""AST-level constant folding and small peephole-style rewrites.

``AstOptimizer`` walks an ``ast`` tree and replaces expressions whose value can
be computed ahead of time (unary/binary operations on constants, constant
tuples, subscripts of constants, ``__debug__``) with ``Constant`` nodes.
"""
import ast
import operator
import sys
from ast import (
    Bytes,
    Constant,
    Ellipsis,
    NameConstant,
    Num,
    Str,
    cmpop,
    copy_location,
)
from typing import Dict, Iterable, Optional, Type

from .peephole import safe_lshift, safe_mod, safe_multiply, safe_power
from .visitor import ASTRewriter


def is_const(node):
    """Return True if *node* is one of the AST node types that hold a literal."""
    return isinstance(node, (Constant, Num, Str, Bytes, Ellipsis, NameConstant))


def get_const_value(node):
    """Return the Python value carried by a constant AST *node*.

    Raises:
        TypeError: if *node* is not one of the node types accepted by
            ``is_const``.
    """
    if isinstance(node, (Constant, NameConstant)):
        return node.value
    elif isinstance(node, Num):
        return node.n
    elif isinstance(node, (Str, Bytes)):
        return node.s
    elif isinstance(node, Ellipsis):
        # ``Ellipsis`` here is the AST node class imported from ``ast``; the
        # value it stands for is the ``...`` singleton.
        return ...

    raise TypeError("Bad constant value")


class Py37Limits:
    """Limit values handed to the ``safe_*`` folding helpers from ``.peephole``.

    How each limit is interpreted is decided by those helpers, not here.
    """

    MAX_INT_SIZE = 128
    MAX_COLLECTION_SIZE = 256
    MAX_STR_SIZE = 4096
    MAX_TOTAL_ITEMS = 1024


# Unary operator node type -> callable computing the folded value.
UNARY_OPS = {
    ast.Invert: operator.invert,
    ast.Not: operator.not_,
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
}

# Comparison operators that have a direct negated counterpart, used to rewrite
# ``not (a is b)`` as ``a is not b`` (and likewise for ``in``).
INVERSE_OPS: Dict[Type[cmpop], Type[cmpop]] = {
    ast.Is: ast.IsNot,
    ast.IsNot: ast.Is,
    ast.In: ast.NotIn,
    ast.NotIn: ast.In,
}

# Binary operator node type -> callable computing the folded value. Operators
# routed through ``safe_*`` are given ``Py37Limits`` by the lambdas below.
BIN_OPS = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: lambda l, r: safe_multiply(l, r, Py37Limits),
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: lambda l, r: safe_mod(l, r, Py37Limits),
    ast.Pow: lambda l, r: safe_power(l, r, Py37Limits),
    ast.LShift: lambda l, r: safe_lshift(l, r, Py37Limits),
    ast.RShift: operator.rshift,
    ast.BitOr: operator.or_,
    ast.BitXor: operator.xor,
    ast.BitAnd: operator.and_,
}

# Gates the List -> Tuple rewrite of iterables in ``AstOptimizer._visitIter``.
IS_PY38_ABOVE = sys.version_info >= (3, 8)


class AstOptimizer(ASTRewriter):
    """AST rewriter that folds constant expressions.

    Args:
        optimize: when true, ``__debug__`` folds to ``False`` and ``assert``
            statements are replaced by ``None``.
    """

    def __init__(self, optimize: bool = False):
        super().__init__()
        self.optimize = optimize

    def visitUnaryOp(self, node: ast.UnaryOp) -> ast.expr:
        """Fold a unary operation on a constant, or invert ``not <compare>``."""
        op = self.visit(node.operand)
        if is_const(op):
            conv = UNARY_OPS[type(node.op)]
            val = get_const_value(op)
            try:
                return copy_location(Constant(conv(val)), node)
            except Exception:
                # Folding is best-effort: if the operation fails here (e.g.
                # ``-"a"``), keep the node so the error surfaces at runtime.
                pass
        elif (
            isinstance(node.op, ast.Not)
            and isinstance(op, ast.Compare)
            and len(op.ops) == 1
        ):
            # ``not (a is b)`` -> ``a is not b`` etc.; only for single
            # comparisons whose operator has an entry in INVERSE_OPS.
            cmp_op = op.ops[0]
            new_op = INVERSE_OPS.get(type(cmp_op))
            if new_op is not None:
                return self.update_node(op, ops=[new_op()])

        return self.update_node(node, operand=op)

    def visitBinOp(self, node: ast.BinOp) -> ast.expr:
        """Fold a binary operation whose operands are both constants."""
        left = self.visit(node.left)
        right = self.visit(node.right)

        if is_const(left) and is_const(right):
            handler = BIN_OPS.get(type(node.op))
            if handler is not None:
                lval = get_const_value(left)
                rval = get_const_value(right)
                try:
                    return copy_location(Constant(handler(lval, rval)), node)
                except Exception:
                    # Best-effort: leave the expression alone when evaluation
                    # raises, so the failure happens at runtime instead.
                    pass

        return self.update_node(node, left=left, right=right)

    def makeConstTuple(self, elts: Iterable[ast.expr]) -> Optional[Constant]:
        """Return a ``Constant`` tuple of *elts*' values, or None.

        None is returned as soon as any element is not a constant node.
        *elts* is consumed in a single pass, so one-shot iterators are handled
        correctly as well as lists.
        """
        values = []
        for elt in elts:
            if not is_const(elt):
                return None
            values.append(get_const_value(elt))
        return Constant(tuple(values))

    def visitTuple(self, node: ast.Tuple) -> ast.expr:
        """Fold a tuple of constants in load context into a single constant."""
        elts = self.walk_list(node.elts)

        if isinstance(node.ctx, ast.Load):
            res = self.makeConstTuple(elts)
            if res is not None:
                return copy_location(res, node)

        return self.update_node(node, elts=elts)

    def visitSubscript(self, node: ast.Subscript) -> ast.expr:
        """Fold ``<const>[<const>]`` in load context into the indexed value."""
        value = self.visit(node.value)
        index = self.visit(node.slice)

        # NOTE(review): Python 3.9+ no longer produces ``ast.Index`` wrappers,
        # so this fold presumably only fires on older interpreters - confirm.
        if (
            isinstance(node.ctx, ast.Load)
            and is_const(value)
            and isinstance(index, ast.Index)
            and is_const(index.value)
        ):
            try:
                return copy_location(
                    Constant(get_const_value(value)[get_const_value(index.value)]),
                    node,
                )
            except Exception:
                # Best-effort: an out-of-range or otherwise invalid subscript
                # is left in place to fail at runtime.
                pass

        return self.update_node(node, value=value, slice=index)

    def _visitIter(self, node: ast.expr) -> ast.expr:
        """Rewrite a list/set used only for iteration or membership tests.

        A list of constants becomes a constant tuple; on 3.8+ any other list
        without starred elements becomes a ``Tuple`` node. A set of constants
        becomes a constant ``frozenset``. Other nodes go through
        ``generic_visit``.
        """
        if isinstance(node, ast.List):
            elts = self.walk_list(node.elts)
            res = self.makeConstTuple(elts)
            if res is not None:
                return copy_location(res, node)
            if IS_PY38_ABOVE and not any(isinstance(e, ast.Starred) for e in elts):
                # Carry the list's source location over to the replacement
                # tuple, as is done for every other replacement node here.
                new_tuple = copy_location(ast.Tuple(elts=elts, ctx=node.ctx), node)
                return self.update_node(new_tuple)
            return self.update_node(node, elts=elts)
        elif isinstance(node, ast.Set):
            elts = self.walk_list(node.elts)
            res = self.makeConstTuple(elts)
            if res is not None:
                return copy_location(Constant(frozenset(res.value)), node)

            return self.update_node(node, elts=elts)

        return self.generic_visit(node)

    def visitcomprehension(self, node: ast.comprehension) -> ast.comprehension:
        """Visit a comprehension clause and optimize its iterable."""
        target = self.visit(node.target)
        iterable = self.visit(node.iter)
        ifs = self.walk_list(node.ifs)
        iterable = self._visitIter(iterable)

        return self.update_node(node, target=target, iter=iterable, ifs=ifs)

    def visitFor(self, node: ast.For) -> ast.For:
        """Visit a ``for`` statement and optimize its iterable."""
        target = self.visit(node.target)
        iterable = self.visit(node.iter)
        body = self.walk_list(node.body)
        orelse = self.walk_list(node.orelse)
        iterable = self._visitIter(iterable)

        return self.update_node(
            node, target=target, iter=iterable, body=body, orelse=orelse
        )

    def visitCompare(self, node: ast.Compare) -> ast.expr:
        """Visit a comparison; optimize the right side of a final in/not in."""
        left = self.visit(node.left)
        comparators = self.walk_list(node.comparators)

        if isinstance(node.ops[-1], (ast.In, ast.NotIn)):
            new_iter = self._visitIter(comparators[-1])
            if new_iter is not None and new_iter is not comparators[-1]:
                # Copy before mutating so the list returned by walk_list
                # (possibly the node's own list) is not modified in place.
                comparators = list(comparators)
                comparators[-1] = new_iter

        return self.update_node(node, left=left, comparators=comparators)

    def visitName(self, node: ast.Name):
        """Replace ``__debug__`` with a constant: ``not self.optimize``."""
        if node.id == "__debug__":
            return copy_location(Constant(not self.optimize), node)

        return self.generic_visit(node)

    def visitAssert(self, node: ast.Assert):
        """Return None for ``assert`` when optimizing; otherwise visit it."""
        if self.optimize:
            # Skip asserts if we're optimizing
            return None
        return self.generic_visit(node)