python/graphscope/analytical/udf/compile.py (1,014 lines of code) (raw):

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""An almost-complete Python to Cython AST transformer, with injected
GRAPHSCOPE-specific translation.

Python AST nodes are translated to corresponding Cython AST nodes as is,
except:

1. for top-level method, a Cython type annotation is attached to the function
   signature, for example,

   .. code:: python

       @graphscope.analytical.udf.peval('sssp')
       def PEval(frag, context):
           ...

   will be translated as:

   .. code:: cython

       cdef public void PEval(Fragment *frag, ComputeContext *context):
           ...

   it will make Cython understand what we really want and generate proper
   Cpp code.

2. for invocation on methods inside :code:`graphscope.analytical.udf.core`,
   we generate proper special :code:`cdef` definitions, or proper Cpp
   invocations, just like :code:`cython.declare`, for example,

   .. code:: python

       heap = graphscope.analytical.udf.heap((float, 'node'))
       modified = lang.vector(bool, [False for _ in range(inner_vertices.size())])

   will be translated as:

   .. code:: cython

       cdef priority_queue[pair[double, NodeT]] heap
       cdef vector[bool] modified([False for _ in range(inner_vertices.size())])

   Note that :code:`float` in Python is mapped to :code:`double` in Cython
   (further in Cpp code).

More specifically, we define a series of placeholders in module
:code:`graphscope.analytical.udf.core`, which cannot be executed in pure
python mode. The :code:`graphscope.analytical.udf.xxx` decorators will
translate those ordinary *"assignment and call"* into a :code:`cdef` node
in Cython AST.
"""

import ast
import copy
import functools
import inspect
import textwrap
import types
import warnings

from Cython.CodeWriter import CodeWriter
from Cython.Compiler import Builtin
from Cython.Compiler import StringEncoding
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.ModuleNode import *
from Cython.Compiler.Nodes import *

from graphscope.analytical.udf.patch import patch_cython_codewriter
from graphscope.analytical.udf.utils import CType
from graphscope.analytical.udf.utils import ExpectFuncDef
from graphscope.analytical.udf.utils import LinesWrapper
from graphscope.analytical.udf.utils import PregelAggregatorType
from graphscope.analytical.udf.utils import ProgramModel
from graphscope.framework.errors import check_argument


class GRAPECompiler(ast.NodeVisitor):
    """Translate a Python UDF (a function or an AST) into Cython source code.

    The compiler walks a Python :mod:`ast` tree, builds the equivalent Cython
    AST node by node, and renders it to text with a (patched) Cython
    ``CodeWriter``.
    """

    def __init__(self, name, vd_type, md_type, program_model=ProgramModel.Pregel):
        """
        Args:
            name: str. The name of class.
            vd_type: str. The type of the data stored in vertex.
            md_type: str. The type of the message.
            program_model: ProgramModel. 'Pregel' or 'PIE'
        """
        self._name = name
        self._vd_type = vd_type
        self._md_type = md_type
        self._program_model = program_model
        # store aggregate function indexed by name
        self.__registered_aggregators = {}
        self.__globals = {}
        self.__func_params_name_list = []
        self.__pyx_header = LinesWrapper()

    def set_pregel_program_model(self):
        self._program_model = ProgramModel.Pregel

    def set_pie_program_model(self):
        self._program_model = ProgramModel.PIE

    def parse(self, source):
        """Parse source into cython module node object.

        source: str
            The source code may represent a statement or expression.

        Raises:
            RuntimeError: unsupported ast trans from python to cython.
        """
        tree = ast.parse(textwrap.dedent(source))
        # associate `parent` reference to every node
        for node in ast.walk(tree):
            for child in ast.iter_child_nodes(node):
                setattr(child, "__parent__", node)
        return self.visit(tree)

    def run(self, func_or_ast, pyx_header):
        """Translate a Python function (or an already-parsed AST) to Cython code.

        Args:
            func_or_ast: a plain Python function or an ``ast.AST`` node.
            pyx_header: LinesWrapper stored on the compiler.

        Returns:
            str: the generated Cython source.
        """
        self.__pyx_header = pyx_header
        # we already have an AST: just run it
        if isinstance(func_or_ast, ast.AST):
            cyast = self.visit(func_or_ast)
        else:
            check_argument(isinstance(func_or_ast, types.FunctionType))
            # ignore varargs and keywords
            self.__func_params_name_list = inspect.getfullargspec(func_or_ast).args
            # used to resolve which calls target graphscope UDF placeholders
            self.__globals = func_or_ast.__globals__
            cyast = self.parse(inspect.getsource(func_or_ast))
        writer = patch_cython_codewriter(CodeWriter())
        return "\n".join(writer.write(cyast).lines)

    def compile(self, source):
        """Compile source into cython code."""
        cyast = self.parse(source)
        writer = patch_cython_codewriter(CodeWriter())
        return "\n".join(writer.write(cyast).lines)

    # ------------------------------------------------------------------
    # Helpers building Cython argument declarations
    # ------------------------------------------------------------------

    @staticmethod
    def _make_simple_base_type(arg_loc, type_name, **extra):
        """Build a non-basic, signed ``CSimpleBaseTypeNode`` named `type_name`."""
        return CSimpleBaseTypeNode(
            arg_loc,
            name=type_name,
            is_basic_c_type=0,
            signed=1,
            longness=0,
            is_self_arg=False,
            **extra,
        )

    @staticmethod
    def _make_arg_decl(arg_loc, base_type, declarator):
        """Wrap a base type and declarator into a plain ``CArgDeclNode``."""
        return CArgDeclNode(
            arg_loc,
            base_type=base_type,
            declarator=declarator,
            not_none=0,
            or_none=0,
            default=None,
            annotation=None,
        )

    def make_plain_arg(self, name, arg_loc):
        """Untyped argument `name` (used for lambda parameters)."""
        return self._make_arg_decl(
            arg_loc,
            self._make_simple_base_type(arg_loc, None),
            CNameDeclaratorNode(arg_loc, name=name),
        )

    def make_value_arg(self, value_type, name, arg_loc):
        """Argument `name` passed by value with type `value_type`."""
        return self._make_arg_decl(
            arg_loc,
            self._make_simple_base_type(arg_loc, value_type),
            CNameDeclaratorNode(arg_loc, name=name),
        )

    def make_ptr_arg(self, ptr_type, name, arg_loc):
        """Argument `name` declared as a pointer to `ptr_type`."""
        return self._make_arg_decl(
            arg_loc,
            self._make_simple_base_type(arg_loc, ptr_type),
            CPtrDeclaratorNode(arg_loc, base=CNameDeclaratorNode(arg_loc, name=name)),
        )

    def make_ref_arg(self, ref_type, name, arg_loc):
        """Argument `name` declared as a reference to `ref_type`."""
        return self._make_arg_decl(
            arg_loc,
            self._make_simple_base_type(arg_loc, ref_type, complex=0, templates=None),
            CReferenceDeclaratorNode(
                arg_loc, base=CNameDeclaratorNode(arg_loc, name=name)
            ),
        )

    def make_template_arg(
        self, value_type, value_tpls, name, arg_loc, use_ptr=False, use_ref=False
    ):
        """Argument `name` of templated type ``value_type[value_tpls...]``.

        Args:
            value_type: str. Name of the template class.
            value_tpls: list of str. Template argument type names.
            name: str. Argument name.
            arg_loc: position passed to the Cython nodes.
            use_ptr: declare the argument as a pointer (takes precedence).
            use_ref: declare the argument as a reference.
        """

        def mk_tpl_arg(tpl_name):
            return CComplexBaseTypeNode(
                arg_loc,
                base_type=self._make_simple_base_type(arg_loc, tpl_name),
                declarator=CNameDeclaratorNode(
                    arg_loc, name="", cname=None, default=None
                ),
            )

        tpl_type = TemplatedTypeNode(
            arg_loc,
            positional_args=[mk_tpl_arg(n) for n in value_tpls],
            keyword_args=DictNode(arg_loc, key_value_pairs=[]),
            base_type_node=self._make_simple_base_type(arg_loc, value_type),
        )
        name_decl = CNameDeclaratorNode(arg_loc, name=name)
        if use_ptr:
            declarator = CPtrDeclaratorNode(arg_loc, base=name_decl)
        elif use_ref:
            declarator = CReferenceDeclaratorNode(arg_loc, base=name_decl)
        else:
            declarator = name_decl
        return self._make_arg_decl(arg_loc, tpl_type, declarator)

    # ------------------------------------------------------------------
    # Generic helpers
    # ------------------------------------------------------------------

    def loc(self, node):
        """Return a dummy source position; the `node` argument is ignored."""
        return ["", 0, 0]

    def generic_visit(self, node):
        raise NotImplementedError("AST node %s is not supported yet" % node)

    @staticmethod
    def _literal_value(node):
        """Return the Python value held by a literal AST node.

        Supports ``ast.Constant`` (Python >= 3.8) and the legacy
        ``ast.Str``/``ast.Bytes`` nodes (which expose ``.s``).

        Raises:
            ValueError: `node` is not a literal.
        """
        if isinstance(node, ast.Constant):
            return node.value
        if hasattr(node, "s"):
            return node.s
        raise ValueError("Expect a literal value, got %s" % ast.dump(node))

    def _make_stat_list(self, node, stmts):
        """Translate a list of Python statements into a ``StatListNode``."""
        return StatListNode(self.loc(node), stats=[self.visit(s) for s in stmts])

    def _make_optional_stat_list(self, node, stmts):
        """Like :meth:`_make_stat_list`, but ``None`` for an empty list."""
        return self._make_stat_list(node, stmts) if stmts else None

    def _make_simple_call(self, node):
        """Plain positional call ``func(args...)`` without keyword arguments."""
        return SimpleCallNode(
            self.loc(node),
            function=self.visit(node.func),
            args=[self.visit(arg) for arg in node.args],
        )

    # ------------------------------------------------------------------
    # Module / unsupported constructs
    # ------------------------------------------------------------------

    def visit_Module(self, node):
        # only the first top-level statement is translated
        body = self.visit(node.body[0])
        return ModuleNode(self.loc(node), body=body)

    def visit_ImportFrom(self, node):
        raise RuntimeError("ImportFrom is not supported yet.")

    def visit_Import(self, node):
        raise RuntimeError("Import is not supported yet.")

    def visit_ClassDef(self, node):
        raise RuntimeError("Class definition is not supported yet.")

    def visit_JoinedStr(self, node):
        raise RuntimeError("Joinedstr is not supported yet.")

    # ------------------------------------------------------------------
    # Literals
    # ------------------------------------------------------------------

    def visit_Constant(self, node):
        value = node.value
        # bool is a subclass of int: test it first so True/False become
        # boolean literals, matching visit_NameConstant on older Pythons.
        if isinstance(value, bool):
            return BoolNode(self.loc(node), value=value)
        if isinstance(value, int):
            return IntNode(self.loc(node), value=str(value))
        if isinstance(value, float):
            # We won't have c float, we map floating types to double
            return FloatNode(self.loc(node), value=str(value))
        if isinstance(value, str):
            if node.kind == "u":
                return UnicodeNode(self.loc(node), value=value, bytes_value=None)
            return StringNode(
                self.loc(node),
                value=value,
                unicode_value=StringEncoding.EncodedString(value),
            )
        if value is Ellipsis:
            return EllipsisNode(self.loc(node))
        if isinstance(value, bytes):
            return BytesNode(self.loc(node), value=value)
        if value is None:
            return NoneNode(self.loc(node))
        raise NotImplementedError("Unknown constant value: %s" % node)

    # Legacy literal nodes, produced by ast.parse on Python < 3.8.

    def visit_Num(self, node):
        if isinstance(node.n, int):
            return IntNode(self.loc(node), value=str(node.n))
        if isinstance(node.n, float):
            return FloatNode(self.loc(node), value=str(node.n))
        if isinstance(node.n, complex):
            raise NotImplementedError("Not support complex constant yet")
        raise NotImplementedError("Unknown constant value: %s" % node)

    def visit_Str(self, node):
        return StringNode(
            self.loc(node),
            value=node.s,
            unicode_value=StringEncoding.EncodedString(node.s),
        )

    def visit_Bytes(self, node):
        return BytesNode(self.loc(node), value=node.s)

    def visit_Ellipsis(self, node):
        return EllipsisNode(self.loc(node))

    def visit_NameConstant(self, node):
        if node.value in [True, False]:
            return BoolNode(self.loc(node), value=node.value)
        return NoneNode(self.loc(node))

    # ------------------------------------------------------------------
    # Containers and names
    # ------------------------------------------------------------------

    def visit_List(self, node):
        return ListNode(self.loc(node), args=[self.visit(elt) for elt in node.elts])

    def visit_Tuple(self, node):
        return TupleNode(self.loc(node), args=[self.visit(elt) for elt in node.elts])

    def visit_Set(self, node):
        return SetNode(self.loc(node), args=[self.visit(elt) for elt in node.elts])

    def visit_Dict(self, node):
        kvs = [
            DictItemNode(self.loc(node), key=self.visit(k), value=self.visit(v))
            for k, v in zip(node.keys, node.values)
        ]
        return DictNode(self.loc(node), key_value_pairs=kvs)

    def visit_Name(self, node):
        return NameNode(self.loc(node), name=node.id)

    def visit_Expr(self, node):
        expr = self.visit(node.value)
        # `graphscope.declare(...)` yields a declaration, not an expression
        if isinstance(expr, CVarDefNode):
            return expr
        return ExprStatNode(self.loc(node), expr=expr)

    # ------------------------------------------------------------------
    # Operators
    # ------------------------------------------------------------------

    def visit_UnaryOp(self, node):
        operand = self.visit(node.operand)
        if isinstance(node.op, ast.UAdd):
            return UnaryPlusNode(self.loc(node), operator="+", operand=operand)
        if isinstance(node.op, ast.USub):
            return UnaryMinusNode(self.loc(node), operator="-", operand=operand)
        if isinstance(node.op, ast.Not):
            return NotNode(self.loc(node), operand=operand)
        if isinstance(node.op, ast.Invert):
            return TildeNode(self.loc(node), operator="~", operand=operand)
        raise NotImplementedError("Unknown unary operator: %s" % node.op)

    def visit_UAdd(self, node):
        return "+"

    def visit_USub(self, node):
        return "-"

    def visit_Not(self, node):
        return "not"

    def visit_Invert(self, node):
        return "invert"

    def visit_BinOp(self, node):
        lhs = self.visit(node.left)
        rhs = self.visit(node.right)
        op_mapping = {
            ast.Add: (AddNode, "+"),
            ast.Sub: (SubNode, "-"),
            ast.Mult: (MulNode, "*"),
            ast.Div: (DivNode, "/"),
            ast.FloorDiv: (DivNode, "//"),
            ast.Mod: (ModNode, "%"),
            ast.Pow: (PowNode, "**"),
            ast.MatMult: (MatMultNode, "@"),
            ast.LShift: (IntBinopNode, "<<"),
            ast.RShift: (IntBinopNode, ">>"),
            ast.BitOr: (IntBinopNode, "|"),
            ast.BitXor: (IntBinopNode, "^"),
            ast.BitAnd: (IntBinopNode, "&"),
        }
        op_type, op = op_mapping[type(node.op)]
        return op_type(self.loc(node), operator=op, operand1=lhs, operand2=rhs)

    def visit_Add(self, node):
        return "+"

    def visit_Sub(self, node):
        return "-"

    def visit_Mult(self, node):
        return "*"

    def visit_Div(self, node):
        return "/"

    def visit_FloorDiv(self, node):
        return "//"

    def visit_Mod(self, node):
        return "%"

    def visit_Pow(self, node):
        return "**"

    def visit_LShift(self, node):
        return "<<"

    def visit_RShift(self, node):
        return ">>"

    def visit_BitOr(self, node):
        return "|"

    def visit_BitXor(self, node):
        return "^"

    def visit_BitAnd(self, node):
        return "&"

    def visit_MatMult(self, node):
        return "@"

    def visit_AnnAssign(self, node):
        annotation = NameNode(self.loc(node), name=node.annotation.id)
        lhs = NameNode(self.loc(node), name=node.target.id, annotation=annotation)
        rhs = self.visit(node.value)
        return SingleAssignmentNode(self.loc(node), lhs=lhs, rhs=rhs)

    def visit_BoolOp(self, node):
        """Translate ``a and b and ...`` / ``a or b or ...``.

        Python groups all operands of a chained boolean operator in one
        BoolOp node; fold them left-to-right into nested binary nodes so no
        operand is dropped.
        """
        operator = self.visit(node.op)
        result = self.visit(node.values[0])
        for value in node.values[1:]:
            result = BoolBinopNode(
                self.loc(node),
                operator=operator,
                operand1=result,
                operand2=self.visit(value),
            )
        return result

    def visit_And(self, node):
        return "and"

    def visit_Or(self, node):
        return "or"

    def visit_Compare(self, node):
        operator = self.visit(node.ops[0])
        operand1 = self.visit(node.left)
        operand2 = self.visit(node.comparators[0])
        if len(node.comparators) == 1:  # single comparison
            return PrimaryCmpNode(
                self.loc(node), operator=operator, operand1=operand1, operand2=operand2
            )
        # multiple continuous comparison: build the cascade from the tail
        # (last comparison) back to the second one.
        cascade_node = CascadedCmpNode(
            self.loc(node),
            operator=self.visit(node.ops[-1]),
            operand2=self.visit(node.comparators[-1]),
        )
        for op, comparator in zip(node.ops[-2:0:-1], node.comparators[-2:0:-1]):
            cascade_node = CascadedCmpNode(
                self.loc(node),
                operator=self.visit(op),
                operand2=self.visit(comparator),
                cascade=cascade_node,
            )
        return PrimaryCmpNode(
            self.loc(node),
            operator=operator,
            operand1=operand1,
            operand2=operand2,
            cascade=cascade_node,
        )

    def visit_Eq(self, node):
        return "=="

    def visit_NotEq(self, node):
        return "!="

    def visit_Lt(self, node):
        return "<"

    def visit_LtE(self, node):
        return "<="

    def visit_Gt(self, node):
        return ">"

    def visit_GtE(self, node):
        return ">="

    def visit_Is(self, node):
        return "is"

    def visit_IsNot(self, node):
        return "is not"

    def visit_In(self, node):
        return "in"

    def visit_NotIn(self, node):
        return "not in"

    # ------------------------------------------------------------------
    # Calls
    # ------------------------------------------------------------------

    def __flatten_func_name(self, name):
        """Dotted name parts of `name`, e.g. ``a.b.c`` -> ``["a", "b", "c"]``.

        Non-name prefixes (calls, subscripts, ...) contribute no parts.
        """
        if isinstance(name, ast.Name):
            return [name.id]
        if isinstance(name, ast.Attribute):
            return self.__flatten_func_name(name.value) + [name.attr]
        return []

    def __is_graphscope_api_call(self, node):
        """Whether `node` calls a UDF parameter or a graphscope UDF placeholder."""
        flat_func_name = self.__flatten_func_name(node.func)
        if not flat_func_name:
            return False
        if flat_func_name[0] in self.__func_params_name_list:
            return True
        # check from graphscope module
        cascade = self.__globals.get(flat_func_name[0])
        if cascade is None:
            return False
        for n in flat_func_name[1:]:
            if cascade is None or not hasattr(cascade, n):
                return False
            cascade = getattr(cascade, n)
        # modules and some builtins carry no `__module__` attribute
        return (
            getattr(cascade, "__module__", None) == "graphscope.analytical.udf.types"
        )

    def _make_aggregator_call(self, node, obj, name, nargs):
        """Translate ``context.<name>(agg_name, ...)`` into a typed call.

        The aggregator's C type (looked up from its registration) is
        attached as a template index: ``context.<name>[ctype](...)``.

        Raises:
            ValueError: wrong number of arguments.
            KeyError: the aggregator has not been registered.
        """
        args = node.args
        if len(args) != nargs:
            raise ValueError("Params within %s incorrect." % name)
        agg_name = str(self._literal_value(args[0]))
        if agg_name not in self.__registered_aggregators:
            raise KeyError(
                "Aggregator %s not exist, you may want to register first." % agg_name
            )
        ctype = PregelAggregatorType.extract_ctype(
            self.__registered_aggregators[agg_name]
        )
        return SimpleCallNode(
            self.loc(node),
            function=IndexNode(
                self.loc(node),
                base=AttributeNode(
                    self.loc(node),
                    obj=NameNode(self.loc(node), name=obj),
                    attribute=name,
                ),
                index=NameNode(self.loc(node), name=str(ctype)),
            ),
            args=[self.visit(arg) for arg in args],
        )

    def __visit_GraphScopeAPICall(self, node):
        full_func_name = self.__flatten_func_name(node.func)
        obj = full_func_name[0]
        # a bare callee such as `context(...)` has no attribute part
        name = getattr(node.func, "attr", None)

        if obj == "graphscope" and name == "declare":
            # graphscope.declare(type, var) -> `cdef type var`
            var = node.args[1].id
            var_type = node.args[0].attr
            return CVarDefNode(
                self.loc(node),
                base_type=CSimpleBaseTypeNode(
                    self.loc(node),
                    name=var_type,
                    module_path=[],
                    is_basic_c_type=0,
                    signed=1,
                ),
                declarators=[CNameDeclaratorNode(self.loc(node), name=var)],
                visibility="private",
            )
        if obj == "context" and name == "register_aggregator":
            # context.register_aggregator(agg_name, PregelAggregatorType.X)
            args = node.args
            if len(args) != 2:
                raise ValueError("Params within register_aggregator incorrect.")
            agg_type = args[1]
            if (
                isinstance(agg_type, ast.Attribute)
                and isinstance(agg_type.value, ast.Name)
                and agg_type.value.id == "PregelAggregatorType"
            ):
                agg_name = str(self._literal_value(args[0]))
                self.__registered_aggregators[agg_name] = agg_type.attr
            return self._make_simple_call(node)
        if obj == "context" and name == "aggregate":
            return self._make_aggregator_call(node, obj, name, nargs=2)
        if obj == "context" and name == "get_aggregated_value":
            return self._make_aggregator_call(node, obj, name, nargs=1)
        if obj == "context" and full_func_name[1:2] == ["math"]:
            # context.math.f(...) -> math.f(...)
            mnode = copy.copy(node)
            mnode.func = ast.Attribute(value=ast.Name(id="math"), attr=name)
            return self.visit(mnode)
        return self._make_simple_call(node)

    def visit_Call(self, node):
        if self.__is_graphscope_api_call(node):
            return self.__visit_GraphScopeAPICall(node)
        if not node.keywords:
            return self._make_simple_call(node)
        # with kwargs param
        return GeneralCallNode(
            self.loc(node),
            function=self.visit(node.func),
            positional_args=TupleNode(
                self.loc(node), args=[self.visit(arg) for arg in node.args]
            ),
            keyword_args=self._visit_keywords(node.keywords),
        )

    def _visit_keywords(self, node):
        """Translate a list of ``ast.keyword`` into a keyword-args DictNode."""
        kvs = [self.visit_keyword(keyword) for keyword in node]
        return DictNode(self.loc(node), key_value_pairs=kvs, reject_duplicates=True)

    def visit_keyword(self, node):
        key = IdentifierStringNode(self.loc(node), value=node.arg)
        return DictItemNode(self.loc(node), key=key, value=self.visit(node.value))

    # ------------------------------------------------------------------
    # Expressions
    # ------------------------------------------------------------------

    def visit_IfExp(self, node):
        return CondExprNode(
            self.loc(node),
            test=self.visit(node.test),
            true_val=self.visit(node.body),
            false_val=self.visit(node.orelse),
        )

    def visit_Attribute(self, node):
        full_attr_name = self.__flatten_func_name(node)
        # context.math.x -> math.x (slice avoids IndexError on short names)
        if full_attr_name[:2] == ["context", "math"]:
            mnode = copy.copy(node)
            mnode.value = ast.Name(id="math")
            return self.visit(mnode)
        return AttributeNode(
            self.loc(node), obj=self.visit(node.value), attribute=node.attr
        )

    def visit_Subscript(self, node):
        return IndexNode(
            self.loc(node), base=self.visit(node.value), index=self.visit(node.slice)
        )

    def visit_Index(self, node):
        return self.visit(node.value)

    def visit_Slice(self, node):
        def bound(value):
            return NoneNode(self.loc(node)) if value is None else self.visit(value)

        start = bound(node.lower)
        stop = bound(node.upper)
        step = bound(node.step)
        return SliceNode(self.loc(node), start=start, stop=stop, step=step)

    def visit_ExtSlice(self, node):
        return TupleNode(self.loc(node), args=[self.visit(dim) for dim in node.dims])

    # ------------------------------------------------------------------
    # Comprehensions (single generator, at most one `if`)
    # ------------------------------------------------------------------

    def _build_comprehension_loop(self, node, append_node):
        """Build the ``for ... in ...: [if ...:] append`` loop of a comprehension.

        `node` must already be known to have exactly one generator.
        """
        generator = node.generators[0]
        iter_value = IteratorNode(
            self.loc(generator.iter), sequence=self.visit(generator.iter)
        )
        body = append_node
        if generator.ifs:
            check_argument(len(generator.ifs) == 1)
            condition = self.visit(generator.ifs[0])
            body = IfStatNode(
                self.loc(node),
                if_clauses=[
                    IfClauseNode(self.loc(node), condition=condition, body=append_node)
                ],
                else_clause=None,
            )
        return ForInStatNode(
            self.loc(node),
            target=self.visit(generator.target),
            iterator=iter_value,
            body=body,
            else_clause=None,
            is_async=False,
        )

    def visit_ListComp(self, node):
        check_argument(len(node.generators) == 1)
        comp_node = ComprehensionAppendNode(
            self.loc(node.generators[0]), expr=self.visit(node.elt)
        )
        return ComprehensionNode(
            self.loc(node),
            loop=self._build_comprehension_loop(node, comp_node),
            append=comp_node,
            type=Builtin.list_type,
            has_local_scope=True,
        )

    def visit_SetComp(self, node):
        check_argument(len(node.generators) == 1)
        comp_node = ComprehensionAppendNode(
            self.loc(node.generators[0]), expr=self.visit(node.elt)
        )
        return ComprehensionNode(
            self.loc(node),
            loop=self._build_comprehension_loop(node, comp_node),
            append=comp_node,
            type=Builtin.set_type,
        )

    def visit_DictComp(self, node):
        check_argument(len(node.generators) == 1)
        comp_node = DictComprehensionAppendNode(
            self.loc(node.generators[0]),
            key_expr=self.visit(node.key),
            value_expr=self.visit(node.value),
        )
        return ComprehensionNode(
            self.loc(node),
            loop=self._build_comprehension_loop(node, comp_node),
            append=comp_node,
            type=Builtin.dict_type,
        )

    # ------------------------------------------------------------------
    # Statements
    # ------------------------------------------------------------------

    def visit_Assign(self, node):
        # `tuple` represents a multiple assign
        assert len(node.targets) == 1
        target = node.targets[0]
        # the UDF's own parameters (frag/context/...) must not be rebound
        if getattr(target, "id", None) in self.__func_params_name_list:
            raise RuntimeError("Can't assign to internal variables.")
        lhs = self.visit(target)
        rhs = self.visit(node.value)
        return SingleAssignmentNode(self.loc(node), lhs=lhs, rhs=rhs)

    def visit_AugAssign(self, node):
        return InPlaceAssignmentNode(
            self.loc(node),
            operator=self.visit(node.op),
            lhs=self.visit(node.target),
            rhs=self.visit(node.value),
        )

    def visit_Raise(self, node):
        # a bare `raise` (re-raise) has no exception expression
        return RaiseStatNode(
            self.loc(node),
            exc_type=None if node.exc is None else self.visit(node.exc),
            exc_value=None,
            exc_tb=None,
            cause=None if node.cause is None else self.visit(node.cause),
        )

    def visit_ExceptHandler(self, node):
        pattern = None
        target = None
        if node.type:
            pattern = [self.visit(node.type)]
            if node.name:
                target = NameNode(self.loc(node), name=node.name)
        return ExceptClauseNode(
            self.loc(node),
            pattern=pattern,
            target=target,
            body=self._make_stat_list(node, node.body),
            is_except_as=False,
        )

    def visit_Try(self, node):
        body = self._make_stat_list(node, node.body)
        except_clauses = [self.visit(ec) for ec in node.handlers]
        else_clause = self._make_optional_stat_list(node, node.orelse)
        try_except_stat_node = TryExceptStatNode(
            self.loc(node),
            body=body,
            except_clauses=except_clauses,
            else_clause=else_clause,
        )
        # with `finally` statement or not
        if node.finalbody:
            final_clause = self._make_stat_list(node, node.finalbody)
            return TryFinallyStatNode(
                self.loc(node), body=try_except_stat_node, finally_clause=final_clause
            )
        return try_except_stat_node

    def visit_Assert(self, node):
        return AssertStatNode(
            self.loc(node),
            cond=self.visit(node.test),
            value=self.visit(node.msg) if node.msg else None,
        )

    def visit_Delete(self, node):
        return DelStatNode(
            self.loc(node), args=[self.visit(target) for target in node.targets]
        )

    def visit_Pass(self, node):
        return PassStatNode(self.loc(node))

    def visit_If(self, node):
        condition = self.visit(node.test)
        body = self._make_stat_list(node, node.body)
        else_body = self._make_optional_stat_list(node, node.orelse)
        return IfStatNode(
            self.loc(node),
            if_clauses=[IfClauseNode(self.loc(node), condition=condition, body=body)],
            else_clause=else_body,
        )

    def visit_For(self, node):
        target_value = self.visit(node.target)
        iter_value = IteratorNode(self.loc(node.iter), sequence=self.visit(node.iter))
        body = self._make_stat_list(node, node.body)
        else_body = self._make_optional_stat_list(node, node.orelse)
        return ForInStatNode(
            self.loc(node),
            target=target_value,
            iterator=iter_value,
            body=body,
            else_clause=else_body,
            is_async=False,
        )

    def visit_While(self, node):
        condition = self.visit(node.test)
        body = self._make_stat_list(node, node.body)
        else_body = self._make_optional_stat_list(node, node.orelse)
        return WhileStatNode(
            self.loc(node), condition=condition, body=body, else_clause=else_body
        )

    def visit_withitem(self, node):
        return self.visit(node.context_expr)

    def visit_With(self, node):
        # multiple items is not supported yet
        assert len(node.items) == 1
        item = node.items[0]
        manager = self.visit(item)
        # `with x:` without `as` has no target
        target = None if item.optional_vars is None else self.visit(item.optional_vars)
        body = self._make_stat_list(node, node.body)
        return WithStatNode(self.loc(node), manager=manager, target=target, body=body)

    def visit_Break(self, node):
        return BreakStatNode(self.loc(node))

    def visit_Continue(self, node):
        return ContinueStatNode(self.loc(node))

    # ------------------------------------------------------------------
    # Function definitions
    # ------------------------------------------------------------------

    def visit_FunctionDef(self, node):
        """Translate a UDF entry point into a ``cdef public`` function.

        The expected parameter list and C signature depend on the program
        model (PIE or Pregel) and the method name (see ``ExpectFuncDef``).

        Raises:
            RuntimeError: missing ``@staticmethod``, unknown method name or
                unsupported program model.
            AssertionError: the number of parameters does not match.
        """

        def is_static_method(func):
            return (
                func.decorator_list
                and isinstance(func.decorator_list[0], ast.Name)
                and (func.decorator_list[0].id == "staticmethod")
            )

        if not is_static_method(node):
            raise RuntimeError("Missing decorator staticmethod.")

        function_name = node.name
        function_return_type = "void"
        params = node.args.args
        vd_md = [self._vd_type, self._md_type]

        def template(type_name, tpls, index, use_ref=False):
            return self.make_template_arg(
                type_name,
                tpls,
                params[index].arg,
                self.loc(params[index]),
                use_ref=use_ref,
            )

        if self._program_model == ProgramModel.PIE:
            # Init, PEval and IncEval share the same signature
            if function_name in (
                ExpectFuncDef.INIT.value,
                ExpectFuncDef.PEVAL.value,
                ExpectFuncDef.INCEVAL.value,
            ):
                assert len(params) == 2, "The number of parameters does not match"
                args = [
                    self.make_ref_arg("Fragment", params[0].arg, self.loc(params[0])),
                    template("Context", vd_md, 1, use_ref=True),
                ]
            else:
                raise RuntimeError(
                    "Not recognized method named {}".format(function_name)
                )
        elif self._program_model == ProgramModel.Pregel:
            if function_name == ExpectFuncDef.INIT.value:
                assert len(params) == 2, "The number of parameters does not match"
                args = [
                    template("Vertex", vd_md, 0, use_ref=True),
                    template("Context", vd_md, 1, use_ref=True),
                ]
            elif function_name == ExpectFuncDef.COMPUTE.value:
                assert len(params) == 3, "The number of parameters does not match"
                args = [
                    template("MessageIterator", [self._md_type], 0),
                    template("Vertex", vd_md, 1, use_ref=True),
                    template("Context", vd_md, 2, use_ref=True),
                ]
            elif function_name == ExpectFuncDef.COMBINE.value:
                assert len(params) == 1, "The number of parameters does not match"
                args = [template("MessageIterator", [self._md_type], 0)]
                # the combiner returns the combined message
                function_return_type = self._md_type
            else:
                raise RuntimeError(
                    "Not recognized method named {}".format(function_name)
                )
        else:
            raise RuntimeError(
                "Unsupported program model: {}".format(self._program_model)
            )

        base_type = CSimpleBaseTypeNode(
            self.loc(node),
            name=function_return_type,
            is_basic_c_type=1,
            signed=1,
            longness=0,
            is_self_arg=False,
        )
        declarator = CFuncDeclaratorNode(
            self.loc(node),
            base=CNameDeclaratorNode(self.loc(node), name=function_name),
            args=args,
            has_varargs=False,
            exception_value=None,
            exception_check=False,
            nogil=True,
            with_gil=False,
            overridable=False,
        )
        # traverse body
        body = self._make_stat_list(node, node.body)
        return CFuncDefNode(
            self.loc(node),
            visibility="public",
            base_type=base_type,
            declarator=declarator,
            body=body,
            modifiers=[],
            api=False,
            overridable=False,
            is_const_method=False,
        )

    def visit_Lambda(self, node):
        return LambdaNode(
            self.loc(node),
            args=[
                self.make_plain_arg(arg.arg, self.loc(arg)) for arg in node.args.args
            ],
            star_arg=None,
            starstar_arg=None,
            # Cython's LambdaNode reads its body from `result_expr`
            result_expr=self.visit(node.body),
        )

    def visit_Return(self, node):
        value = None if node.value is None else self.visit(node.value)
        return ReturnStatNode(self.loc(node), value=value)

    def visit_Yield(self, node):
        return YieldExprNode(self.loc(node), expr=self.visit(node.value))

    def visit_YieldFrom(self, node):
        return YieldFromExprNode(self.loc(node), expr=self.visit(node.value))

    def visit_Global(self, node):
        return GlobalNode(self.loc(node), names=node.names)

    def visit_Nonlocal(self, node):
        return NonlocalNode(self.loc(node), names=node.names)

    def visit_Await(self, node):
        return AwaitExprNode(self.loc(node), expr=self.visit(node.value))