python/graphscope/analytical/udf/compile.py (1,014 lines of code) (raw):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
#
# Copyright 2020 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""An almost-complete Python to Cython AST transformer, with injected
GRAPHSCOPE-specific translation.
Python AST nodes are translated to corresponding Cython AST nodes as
it is, except:
1. for top-level method, a Cython type annotation is attached to the
function signature, for example,
.. code:: python
@graphscope.analytical.udf.peval('sssp')
def PEval(frag, context):
...
will be translated as:
.. code:: cython
cdef public void PEval(Fragment *frag, ComputeContext *context):
...
it will make Cython understand what we really want and generate proper
Cpp code.
2. for invocations of methods inside :code:`graphscope.analytical.udf.core`, we generate
proper special :code:`cdef` definitions, or proper Cpp invocations,
just like :code:`cython.declare`, for example,
.. code:: python
heap = graphscope.analytical.udf.heap((float, 'node'))
modified = lang.vector(bool, [False for _ in range(inner_vertices.size())])
will be translated as:
.. code:: cython
cdef priority_queue[pair[double, NodeT]] heap
cdef vector[bool] modified([False for _ in range(inner_vertices.size())])
Note that :code:`float` in Python is mapped to :code:`double` in
Cython (further in Cpp code).
More specifically, we define a series of placeholders in module
:code:`graphscope.analytical.udf.core`, which cannot be executed in pure python mode.
The :code:`graphscope.analytical.udf.xxx` decorators will translate those ordinary
*"assignment and call"* into a :code:`cdef` node in Cython AST.
"""
import ast
import copy
import functools
import inspect
import textwrap
import types
import warnings
from Cython.CodeWriter import CodeWriter
from Cython.Compiler import Builtin
from Cython.Compiler import StringEncoding
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.ModuleNode import *
from Cython.Compiler.Nodes import *
from graphscope.analytical.udf.patch import patch_cython_codewriter
from graphscope.analytical.udf.utils import CType
from graphscope.analytical.udf.utils import ExpectFuncDef
from graphscope.analytical.udf.utils import LinesWrapper
from graphscope.analytical.udf.utils import PregelAggregatorType
from graphscope.analytical.udf.utils import ProgramModel
from graphscope.framework.errors import check_argument
class GRAPECompiler(ast.NodeVisitor):
def __init__(self, name, vd_type, md_type, program_model=ProgramModel.Pregel):
    """Create a compiler for one user-defined app.

    Args:
        name: str. Name of the class being compiled.
        vd_type: str. Type of the data stored on a vertex.
        md_type: str. Type of the message.
        program_model: ProgramModel. Either ``Pregel`` or ``PIE``.
    """
    # Per-compilation state, all starting out empty.
    self.__pyx_header = LinesWrapper()
    self.__func_params_name_list = []
    self.__globals = {}
    # Aggregator type recorded under the aggregator's name.
    self.__registered_aggregators = {}
    self._program_model = program_model
    self._md_type = md_type
    self._vd_type = vd_type
    self._name = name
def set_pregel_program_model(self):
    """Switch this compiler to the Pregel program model."""
    self._program_model = ProgramModel.Pregel
def set_pie_program_model(self):
    """Switch this compiler to the PIE program model."""
    self._program_model = ProgramModel.PIE
def parse(self, source):
    """Translate Python source text into a Cython AST.

    Args:
        source: str. Python source code; common leading indentation is
            removed before parsing.

    Returns:
        The Cython node obtained by visiting the parsed module.

    Raises:
        RuntimeError: the source uses a construct that has no supported
            Python-to-Cython translation.
    """
    py_tree = ast.parse(textwrap.dedent(source))
    # Give every node a back-reference to its parent node.
    for parent in ast.walk(py_tree):
        for kid in ast.iter_child_nodes(parent):
            setattr(kid, "__parent__", parent)
    return self.visit(py_tree)
def run(self, func_or_ast, pyx_header):
    """Compile a Python function, or a ready-made Python AST, to Cython code.

    Args:
        func_or_ast: an ``ast.AST`` to translate directly, or a plain
            Python function whose source is fetched and parsed.
        pyx_header: header object stored on the compiler for this run.

    Returns:
        str. The generated Cython source, lines joined by newlines.
    """
    self.__pyx_header = pyx_header
    if not isinstance(func_or_ast, ast.AST):
        check_argument(isinstance(func_or_ast, types.FunctionType))
        # Only the named positional parameters are recorded; varargs and
        # keywords are ignored.
        self.__func_params_name_list = inspect.getfullargspec(func_or_ast).args
        self.__globals = func_or_ast.__globals__
        cyast = self.parse(inspect.getsource(func_or_ast))
    else:
        # Already an AST: translate it as is.
        cyast = self.visit(func_or_ast)
    emitter = patch_cython_codewriter(CodeWriter())
    emitted = emitter.write(cyast)
    return "\n".join(emitted.lines)
def compile(self, source):
    """Translate Python source text and return the Cython code as a string."""
    cython_tree = self.parse(source)
    emitter = patch_cython_codewriter(CodeWriter())
    emitted_lines = emitter.write(cython_tree).lines
    return "\n".join(emitted_lines)
def make_plain_arg(self, name, arg_loc):
    """Build an argument declaration named ``name`` with no type name."""
    untyped = CSimpleBaseTypeNode(
        arg_loc,
        name=None,
        is_basic_c_type=0,
        signed=1,
        longness=0,
        is_self_arg=False,
    )
    plain_name = CNameDeclaratorNode(arg_loc, name=name)
    return CArgDeclNode(
        arg_loc,
        base_type=untyped,
        declarator=plain_name,
        not_none=0,
        or_none=0,
        default=None,
        annotation=None,
    )
def make_value_arg(self, value_type, name, arg_loc):
    """Build a by-value argument declaration ``value_type name``."""
    typed = CSimpleBaseTypeNode(
        arg_loc,
        name=value_type,
        is_basic_c_type=0,
        signed=1,
        longness=0,
        is_self_arg=False,
    )
    plain_name = CNameDeclaratorNode(arg_loc, name=name)
    return CArgDeclNode(
        arg_loc,
        base_type=typed,
        declarator=plain_name,
        not_none=0,
        or_none=0,
        default=None,
        annotation=None,
    )
def make_ptr_arg(self, ptr_type, name, arg_loc):
    """Build a pointer argument declaration ``ptr_type *name``."""
    pointee = CSimpleBaseTypeNode(
        arg_loc,
        name=ptr_type,
        is_basic_c_type=0,
        signed=1,
        longness=0,
        is_self_arg=False,
    )
    inner_name = CNameDeclaratorNode(arg_loc, name=name)
    pointer = CPtrDeclaratorNode(arg_loc, base=inner_name)
    return CArgDeclNode(
        arg_loc,
        base_type=pointee,
        declarator=pointer,
        not_none=0,
        or_none=0,
        default=None,
        annotation=None,
    )
def make_ref_arg(self, ref_type, name, arg_loc):
    """Build a reference argument declaration ``ref_type &name``."""
    referee = CSimpleBaseTypeNode(
        arg_loc,
        name=ref_type,
        is_basic_c_type=0,
        signed=1,
        longness=0,
        complex=0,
        is_self_arg=False,
        templates=None,
    )
    inner_name = CNameDeclaratorNode(arg_loc, name=name)
    reference = CReferenceDeclaratorNode(arg_loc, base=inner_name)
    return CArgDeclNode(
        arg_loc,
        base_type=referee,
        declarator=reference,
        not_none=0,
        or_none=0,
        default=None,
        annotation=None,
    )
def make_template_arg(
    self, value_type, value_tpls, name, arg_loc, use_ptr=False, use_ref=False
):
    """Build an argument declaration of templated type ``value_type[value_tpls...]``.

    Args:
        value_type: name of the templated type.
        value_tpls: iterable of type names used as template parameters.
        name: argument name.
        arg_loc: position handed to every created node.
        use_ptr: declare the argument as a pointer (checked first).
        use_ref: declare the argument as a reference (only when
            ``use_ptr`` is false).
    """

    def simple_type(type_name):
        return CSimpleBaseTypeNode(
            arg_loc,
            name=type_name,
            is_basic_c_type=0,
            signed=1,
            longness=0,
            is_self_arg=False,
        )

    template_params = []
    for tpl_name in value_tpls:
        anonymous = CNameDeclaratorNode(arg_loc, name="", cname=None, default=None)
        template_params.append(
            CComplexBaseTypeNode(
                arg_loc, base_type=simple_type(tpl_name), declarator=anonymous
            )
        )
    templated = TemplatedTypeNode(
        arg_loc,
        positional_args=template_params,
        keyword_args=DictNode(arg_loc, key_value_pairs=[]),
        base_type_node=simple_type(value_type),
    )

    declarator = CNameDeclaratorNode(arg_loc, name=name)
    if use_ptr:
        declarator = CPtrDeclaratorNode(arg_loc, base=declarator)
    elif use_ref:
        declarator = CReferenceDeclaratorNode(arg_loc, base=declarator)

    return CArgDeclNode(
        arg_loc,
        base_type=templated,
        declarator=declarator,
        not_none=0,
        or_none=0,
        default=None,
        annotation=None,
    )
def loc(self, node):
    """Return the position used for generated nodes.

    ``node`` is ignored: the same placeholder ``["", 0, 0]`` is returned
    for every input.
    """
    return ["", 0, 0]
def generic_visit(self, node):
    """Fallback for nodes without a dedicated visitor: always raises.

    Raises:
        NotImplementedError: for every node that reaches this method.
    """
    raise NotImplementedError("AST node %s is not supported yet" % node)
def visit_Module(self, node):
    """Wrap the translated module body in a Cython ``ModuleNode``.

    Only the first statement of the module (``node.body[0]``) is
    translated; any further top-level statements are not visited.
    """
    body = self.visit(node.body[0])
    return ModuleNode(self.loc(node), body=body)
def visit_ImportFrom(self, node):
    """Reject ``from ... import ...``: always raises RuntimeError."""
    raise RuntimeError("ImportFrom is not supported yet.")
def visit_Import(self, node):
    """Reject ``import ...``: always raises RuntimeError."""
    raise RuntimeError("Import is not supported yet.")
def visit_ClassDef(self, node):
    """Reject class definitions: always raises RuntimeError."""
    raise RuntimeError("Class definition is not supported yet.")
def visit_JoinedStr(self, node):
    """Reject f-strings: always raises RuntimeError."""
    raise RuntimeError("Joinedstr is not supported yet.")
def visit_Constant(self, node):
    """Translate an ``ast.Constant`` into the matching Cython literal node.

    Handles bool, int, float, str (including the ``u`` prefix kind),
    ``Ellipsis``, bytes and ``None``.

    Raises:
        NotImplementedError: for any other constant type (e.g. complex).
    """
    value = node.value
    pos = self.loc(node)
    # bool must be tested before int: bool is a subclass of int, and
    # True/False would otherwise be emitted as IntNode("True"/"False").
    if isinstance(value, bool):
        return BoolNode(pos, value=value)
    if isinstance(value, int):
        return IntNode(pos, value=str(value))
    if isinstance(value, float):
        # We won't have c float, we map floating types to double
        return FloatNode(pos, value=str(value))
    if isinstance(value, str):
        if getattr(node, "kind", None) == "u":
            return UnicodeNode(pos, value=value, bytes_value=None)
        return StringNode(
            pos,
            value=value,
            unicode_value=StringEncoding.EncodedString(value),
        )
    if value is Ellipsis:
        return EllipsisNode(pos)
    if isinstance(value, bytes):
        # Read `value` rather than the deprecated `s` alias.
        return BytesNode(pos, value=value)
    if value is None:
        return NoneNode(pos)
    raise NotImplementedError("Unknown constant value: %s" % node)
def visit_Num(self, node):
    """Translate a legacy ``ast.Num`` literal into an int or float node.

    Raises:
        NotImplementedError: for complex numbers and any other value type.
    """
    number = node.n
    for py_type, cy_node in ((int, IntNode), (float, FloatNode)):
        if isinstance(number, py_type):
            return cy_node(self.loc(node), value=str(number))
    if isinstance(number, complex):
        raise NotImplementedError("Not support complex constant yet")
    raise NotImplementedError("Unknown constant value: %s" % node)
def visit_Str(self, node):
    """Translate a legacy ``ast.Str`` literal into a Cython ``StringNode``."""
    return StringNode(
        self.loc(node),
        value=node.s,
        unicode_value=StringEncoding.EncodedString(node.s),
    )
def visit_Bytes(self, node):
    """Translate a legacy ``ast.Bytes`` literal into a Cython ``BytesNode``."""
    return BytesNode(self.loc(node), value=node.s)
def visit_List(self, node):
    """Translate a list display, visiting each element in order."""
    return ListNode(self.loc(node), args=[self.visit(elt) for elt in node.elts])
def visit_Tuple(self, node):
    """Translate a tuple display, visiting each element in order."""
    return TupleNode(self.loc(node), args=[self.visit(elt) for elt in node.elts])
def visit_Set(self, node):
    """Translate a set display, visiting each element in order."""
    return SetNode(self.loc(node), args=[self.visit(elt) for elt in node.elts])
def visit_Dict(self, node):
    """Translate a dict display into a ``DictNode`` of key/value items."""
    items = []
    for key_expr, value_expr in zip(node.keys, node.values):
        cy_key = self.visit(key_expr)
        cy_value = self.visit(value_expr)
        items.append(DictItemNode(self.loc(node), key=cy_key, value=cy_value))
    return DictNode(self.loc(node), key_value_pairs=items)
def visit_Ellipsis(self, node):
    """Translate a legacy ``ast.Ellipsis`` node into an ``EllipsisNode``."""
    return EllipsisNode(self.loc(node))
def visit_NameConstant(self, node):
    """Translate a legacy ``ast.NameConstant``.

    ``True``/``False`` become a ``BoolNode``; anything else a ``NoneNode``.
    """
    pos = self.loc(node)
    if node.value not in (True, False):
        return NoneNode(pos)
    return BoolNode(pos, value=node.value)
def visit_Name(self, node):
    """Translate an identifier into a Cython ``NameNode`` of the same name."""
    return NameNode(self.loc(node), name=node.id)
def visit_Expr(self, node):
    """Translate an expression statement.

    A translated ``CVarDefNode`` is already a statement and is returned
    unchanged; every other result is wrapped in ``ExprStatNode``.
    """
    translated = self.visit(node.value)
    if not isinstance(translated, CVarDefNode):
        translated = ExprStatNode(self.loc(node), expr=translated)
    return translated
def visit_UnaryOp(self, node):
    """Translate ``+x``, ``-x``, ``not x`` and ``~x``.

    Returns ``None`` when the operator is none of those four.
    """
    kind = node.op
    if isinstance(kind, ast.Not):
        # NotNode takes no `operator` keyword.
        return NotNode(self.loc(node), operand=self.visit(node.operand))
    symbolic = (
        (ast.UAdd, UnaryPlusNode, "+"),
        (ast.USub, UnaryMinusNode, "-"),
        (ast.Invert, TildeNode, "~"),
    )
    for op_cls, cy_cls, symbol in symbolic:
        if isinstance(kind, op_cls):
            return cy_cls(
                self.loc(node), operator=symbol, operand=self.visit(node.operand)
            )
    return None
def visit_UAdd(self, node):
    """Return the operator symbol for unary plus."""
    return "+"
def visit_USub(self, node):
    """Return the operator symbol for unary minus."""
    return "-"
def visit_Not(self, node):
    """Return the operator keyword for logical negation."""
    return "not"
def visit_Invert(self, node):
    """Return the string ``"invert"`` for the bitwise-invert operator."""
    return "invert"
def visit_BinOp(self, node):
    """Translate a binary arithmetic or bitwise operation.

    Both operands are translated first, then the operator type is looked
    up; an operator missing from the table raises ``KeyError``.
    """
    left = self.visit(node.left)
    right = self.visit(node.right)
    translations = {
        ast.Add: (AddNode, "+"),
        ast.Sub: (SubNode, "-"),
        ast.Mult: (MulNode, "*"),
        ast.Div: (DivNode, "/"),
        ast.FloorDiv: (DivNode, "//"),
        ast.Mod: (ModNode, "%"),
        ast.Pow: (PowNode, "**"),
        ast.MatMult: (MatMultNode, "@"),
    }
    # Shifts and bitwise operators all share one Cython node class.
    bitwise = (
        (ast.LShift, "<<"),
        (ast.RShift, ">>"),
        (ast.BitOr, "|"),
        (ast.BitXor, "^"),
        (ast.BitAnd, "&"),
    )
    translations.update((op_cls, (IntBinopNode, sym)) for op_cls, sym in bitwise)
    node_cls, symbol = translations[type(node.op)]
    return node_cls(
        self.loc(symbol), operator=symbol, operand1=left, operand2=right
    )
def visit_Add(self, node):
    """Return the operator symbol for addition."""
    return "+"
def visit_Sub(self, node):
    """Return the operator symbol for subtraction."""
    return "-"
def visit_Mult(self, node):
    """Return the operator symbol for multiplication."""
    return "*"
def visit_Div(self, node):
    """Return the operator symbol for true division."""
    return "/"
def visit_FloorDiv(self, node):
    """Return the operator symbol for floor division."""
    return "//"
def visit_Mod(self, node):
    """Return the operator symbol for modulo."""
    return "%"
def visit_Pow(self, node):
    """Return the operator symbol for exponentiation."""
    return "**"
def visit_LShift(self, node):
    """Return the operator symbol for left shift."""
    return "<<"
def visit_RShift(self, node):
    """Return the operator symbol for right shift."""
    return ">>"
def visit_BitOr(self, node):
    """Return the operator symbol for bitwise or."""
    return "|"
def visit_BitXor(self, node):
    """Return the operator symbol for bitwise xor."""
    return "^"
def visit_BitAnd(self, node):
    """Return the operator symbol for bitwise and."""
    return "&"
def visit_MatMult(self, node):
    """Return the operator symbol for matrix multiplication."""
    return "@"
def visit_AnnAssign(self, node):
    """Translate ``target: annotation = value`` into a single assignment.

    Reads ``node.annotation.id`` and ``node.target.id`` directly, so both
    are expected to be plain names.
    """
    pos = self.loc(node)
    type_name = NameNode(pos, name=node.annotation.id)
    target = NameNode(self.loc(node), name=node.target.id, annotation=type_name)
    assigned = self.visit(node.value)
    return SingleAssignmentNode(self.loc(node), lhs=target, rhs=assigned)
def visit_BoolOp(self, node):
    """Translate ``a and b and ...`` / ``a or b or ...``.

    ``ast.BoolOp`` stores every operand of a same-operator chain in
    ``node.values``, so all of them are folded into nested binary nodes:
    ``a and b and c`` becomes ``a and (b and c)``.
    """
    operator = self.visit(node.op)
    operands = [self.visit(value) for value in node.values]
    # Fold from the right so that no operand is dropped.
    result = operands[-1]
    for operand in reversed(operands[:-1]):
        result = BoolBinopNode(
            self.loc(node),
            operator=operator,
            operand1=operand,
            operand2=result,
        )
    return result
def visit_And(self, node):
    """Return the operator keyword for logical and."""
    return "and"
def visit_Or(self, node):
    """Return the operator keyword for logical or."""
    return "or"
def visit_Compare(self, node):
    """Translate a comparison, including chains such as ``a < b <= c``.

    The first operator and its two operands form the ``PrimaryCmpNode``;
    each further ``op comparator`` pair becomes a ``CascadedCmpNode``
    linked through ``cascade``, built from the last pair backwards.
    """
    first_op = self.visit(node.ops[0])
    left = self.visit(node.left)
    first_right = self.visit(node.comparators[0])

    if len(node.comparators) == 1:
        # Plain two-operand comparison.
        return PrimaryCmpNode(
            self.loc(node), operator=first_op, operand1=left, operand2=first_right
        )

    # Innermost link: the last operator/comparator pair.
    chain = CascadedCmpNode(
        self.loc(node),
        operator=self.visit(node.ops[-1]),
        operand2=self.visit(node.comparators[-1]),
    )
    # Remaining middle pairs, walking backwards and excluding index 0.
    middle_ops = node.ops[-2:0:-1]
    middle_rights = node.comparators[-2:0:-1]
    for cmp_op, right in zip(middle_ops, middle_rights):
        chain = CascadedCmpNode(
            self.loc(node),
            operator=self.visit(cmp_op),
            operand2=self.visit(right),
            cascade=chain,
        )
    return PrimaryCmpNode(
        self.loc(node),
        operator=first_op,
        operand1=left,
        operand2=first_right,
        cascade=chain,
    )
def visit_Eq(self, node):
    """Return the operator symbol for equality."""
    return "=="
def visit_NotEq(self, node):
    """Return the operator symbol for inequality."""
    return "!="
def visit_Lt(self, node):
    """Return the operator symbol for less-than."""
    return "<"
def visit_LtE(self, node):
    """Return the operator symbol for less-than-or-equal."""
    return "<="
def visit_Gt(self, node):
    """Return the operator symbol for greater-than."""
    return ">"
def visit_GtE(self, node):
    """Return the operator symbol for greater-than-or-equal."""
    return ">="
def visit_Is(self, node):
    """Return the operator keyword for identity comparison."""
    return "is"
def visit_IsNot(self, node):
    """Return the operator keywords for negated identity comparison."""
    return "is not"
def visit_In(self, node):
    """Return the operator keyword for membership."""
    return "in"
def visit_NotIn(self, node):
    """Return the operator keywords for negated membership."""
    return "not in"
def __flatten_func_name(self, name):
    """Flatten a dotted expression such as ``a.b.c`` into ``["a", "b", "c"]``.

    Attribute names are collected while walking down ``.value``. The root
    contributes its identifier only when it is an ``ast.Name``; any other
    root (a call, a subscript, ...) contributes nothing.
    """
    parts = []
    current = name
    while isinstance(current, ast.Attribute):
        parts.append(current.attr)
        current = current.value
    if isinstance(current, ast.Name):
        parts.append(current.id)
    # Collected from the outermost attribute inwards; restore source order.
    parts.reverse()
    return parts
def __is_graphscope_api_call(self, node):
    """Tell whether a call node needs GraphScope-specific translation.

    A call qualifies when its dotted name starts with one of the compiled
    function's parameter names, or when the name resolves, through the
    function's globals, to an object whose ``__module__`` is
    ``"graphscope.analytical.udf.types"``.
    """
    flat_func_name = self.__flatten_func_name(node.func)
    if not flat_func_name:
        return False
    if flat_func_name[0] in self.__func_params_name_list:
        return True
    # Resolve the dotted name starting from the function's globals.
    resolved = self.__globals.get(flat_func_name[0])
    if resolved is None:
        return False
    for attr in flat_func_name[1:]:
        if resolved is None or not hasattr(resolved, attr):
            return False
        resolved = getattr(resolved, attr)
    # Not every object carries `__module__` (module objects do not), so
    # read it defensively instead of raising AttributeError.
    return getattr(resolved, "__module__", None) == "graphscope.analytical.udf.types"
def __visit_GraphScopeAPICall(self, node):
    """Translate a call that was recognized as a GraphScope API call.

    Special cases:
      * ``graphscope.declare(type, var)`` becomes a private ``cdef``.
      * ``context.register_aggregator(name, PregelAggregatorType.X)``
        records the aggregator type under its name.
      * ``context.aggregate`` / ``context.get_aggregated_value`` are
        indexed with the C type of the registered aggregator.
      * ``context.math.f(...)`` is rewritten to ``math.f(...)``.
    Every other call is emitted as an ordinary call.

    Raises:
        ValueError: wrong number of arguments for an aggregator call.
        KeyError: the aggregator was not registered beforehand.
    """
    full_func_name = self.__flatten_func_name(node.func)
    obj = full_func_name[0]
    # A bare-name call such as ``frag(...)`` has no attribute part.
    name = node.func.attr if isinstance(node.func, ast.Attribute) else None

    if obj == "graphscope" and name == "declare":
        # graphscope.declare()
        var = node.args[1].id
        var_type = node.args[0].attr
        return CVarDefNode(
            self.loc(node),
            base_type=CSimpleBaseTypeNode(
                self.loc(node),
                name=var_type,
                module_path=[],
                is_basic_c_type=0,
                signed=1,
            ),
            declarators=[CNameDeclaratorNode(self.loc(node), name=var)],
            visibility="private",
        )

    if obj == "context":
        if name == "register_aggregator":
            # context.register_aggregator()
            args = node.args
            if len(args) != 2:
                raise ValueError("Params within register_aggregator incorrect.")
            if (
                isinstance(args[1], ast.Attribute)
                and args[1].value.id == "PregelAggregatorType"
            ):
                self.__registered_aggregators[str(args[0].s)] = args[1].attr
        elif name == "aggregate":
            # context.aggregate()
            return self.__typed_aggregator_call(node, obj, name, 2)
        elif name == "get_aggregated_value":
            # context.get_aggregated_value()
            return self.__typed_aggregator_call(node, obj, name, 1)
        elif len(full_func_name) > 1 and full_func_name[1] == "math":
            # context.math.f(...) -> math.f(...)
            mnode = copy.copy(node)
            mnode.func = ast.Attribute(value=ast.Name(id="math"), attr=name)
            return self.visit(mnode)

    # Default for every remaining API call, including ``graphscope.<x>``
    # calls other than ``declare``, which previously returned None.
    return SimpleCallNode(
        self.loc(node),
        function=self.visit(node.func),
        args=[self.visit(arg) for arg in node.args],
    )

def __typed_aggregator_call(self, node, obj, name, expected_argc):
    """Build ``obj.name[ctype](args...)`` for a registered aggregator."""
    args = node.args
    if len(args) != expected_argc:
        raise ValueError("Params within %s incorrect." % name)
    aggregator = str(args[0].s)
    if aggregator not in self.__registered_aggregators:
        raise KeyError(
            "Aggregator %s not exist, you may want to register first."
            % aggregator
        )
    ctype = PregelAggregatorType.extract_ctype(
        self.__registered_aggregators[aggregator]
    )
    return SimpleCallNode(
        self.loc(node),
        function=IndexNode(
            self.loc(node),
            base=AttributeNode(
                self.loc(node),
                obj=NameNode(self.loc(node), name=obj),
                attribute=name,
            ),
            index=NameNode(self.loc(node), name=str(ctype)),
        ),
        args=[self.visit(arg) for arg in node.args],
    )
def visit_Call(self, node):
    """Translate a call expression.

    GraphScope API calls get their dedicated translation. Other calls
    become a ``GeneralCallNode`` when keyword arguments are present and a
    ``SimpleCallNode`` otherwise.
    """
    if self.__is_graphscope_api_call(node):
        return self.__visit_GraphScopeAPICall(node)
    pos = self.loc(node)
    if node.keywords:
        callee = self.visit(node.func)
        positional = TupleNode(
            self.loc(node), args=[self.visit(arg) for arg in node.args]
        )
        return GeneralCallNode(
            pos,
            function=callee,
            positional_args=positional,
            keyword_args=self._visit_keywords(node.keywords),
        )
    return SimpleCallNode(
        pos,
        function=self.visit(node.func),
        args=[self.visit(arg) for arg in node.args],
    )
def _visit_keywords(self, node):
    """Translate a list of ``ast.keyword`` into a duplicate-rejecting ``DictNode``."""
    pairs = [self.visit_keyword(kw) for kw in node]
    return DictNode(self.loc(node), key_value_pairs=pairs, reject_duplicates=True)
def visit_keyword(self, node):
    """Translate one ``name=value`` call argument into a ``DictItemNode``."""
    key = IdentifierStringNode(self.loc(node), value=node.arg)
    return DictItemNode(self.loc(node), key=key, value=self.visit(node.value))
def visit_IfExp(self, node):
    """Translate the conditional expression ``body if test else orelse``."""
    pos = self.loc(node)
    cond = self.visit(node.test)
    when_true = self.visit(node.body)
    when_false = self.visit(node.orelse)
    return CondExprNode(pos, test=cond, true_val=when_true, false_val=when_false)
def visit_Attribute(self, node):
    """Translate attribute access, rewriting ``context.math.x`` to ``math.x``."""
    full_attr_name = self.__flatten_func_name(node)
    # The flattened name can have a single element (e.g. ``f().context``),
    # so check the length before looking at the second component.
    if (
        len(full_attr_name) > 1
        and full_attr_name[0] == "context"
        and full_attr_name[1] == "math"
    ):
        mnode = copy.copy(node)
        mnode.value = ast.Name(id="math")
        return self.visit(mnode)
    return AttributeNode(
        self.loc(node), obj=self.visit(node.value), attribute=node.attr
    )
def visit_Subscript(self, node):
    """Translate ``value[slice]`` into a Cython ``IndexNode``."""
    return IndexNode(
        self.loc(node), base=self.visit(node.value), index=self.visit(node.slice)
    )
def visit_Index(self, node):
    """Translate a legacy ``ast.Index`` wrapper by visiting the wrapped value."""
    return self.visit(node.value)
def visit_Slice(self, node):
    """Translate ``lower:upper:step``; an omitted part becomes a ``NoneNode``."""

    def bound(part):
        if part is None:
            return NoneNode(self.loc(node))
        return self.visit(part)

    first = bound(node.lower)
    last = bound(node.upper)
    stride = bound(node.step)
    return SliceNode(self.loc(node), start=first, stop=last, step=stride)
def visit_ExtSlice(self, node):
    """Translate a legacy multi-dimensional slice into a tuple of its dims."""
    return TupleNode(self.loc(node), args=[self.visit(dim) for dim in node.dims])
def visit_ListComp(self, node):
    """Translate ``[elt for target in iter if cond]``.

    Only a single ``for`` clause with at most one ``if`` is supported.
    """
    check_argument(len(node.generators) == 1)
    generator = node.generators[0]
    expression_value = self.visit(node.elt)
    iter_value = IteratorNode(
        self.loc(generator.iter), sequence=self.visit(generator.iter)
    )
    comp_node = ComprehensionAppendNode(self.loc(generator), expr=expression_value)
    loop = self.__comprehension_loop(node, generator, iter_value, comp_node)
    return ComprehensionNode(
        self.loc(node),
        loop=loop,
        append=comp_node,
        type=Builtin.list_type,
        has_local_scope=True,
    )

def visit_SetComp(self, node):
    """Translate ``{elt for target in iter if cond}``.

    Only a single ``for`` clause with at most one ``if`` is supported.
    """
    check_argument(len(node.generators) == 1)
    generator = node.generators[0]
    expression_value = self.visit(node.elt)
    iter_value = IteratorNode(
        self.loc(generator.iter), sequence=self.visit(generator.iter)
    )
    comp_node = ComprehensionAppendNode(self.loc(generator), expr=expression_value)
    loop = self.__comprehension_loop(node, generator, iter_value, comp_node)
    return ComprehensionNode(
        self.loc(node), loop=loop, append=comp_node, type=Builtin.set_type
    )

def visit_DictComp(self, node):
    """Translate ``{key: value for target in iter if cond}``.

    Only a single ``for`` clause with at most one ``if`` is supported.
    """
    check_argument(len(node.generators) == 1)
    generator = node.generators[0]
    iter_value = IteratorNode(
        self.loc(generator.iter), sequence=self.visit(generator.iter)
    )
    comp_node = DictComprehensionAppendNode(
        self.loc(generator),
        key_expr=self.visit(node.key),
        value_expr=self.visit(node.value),
    )
    loop = self.__comprehension_loop(node, generator, iter_value, comp_node)
    return ComprehensionNode(
        self.loc(node), loop=loop, append=comp_node, type=Builtin.dict_type
    )

def __comprehension_loop(self, node, generator, iter_value, append_node):
    """Build the ``for`` loop, with optional ``if`` filter, of a comprehension."""
    # Validated with check_argument, not assert, so that the check also
    # runs under ``python -O``.
    check_argument(len(generator.ifs) <= 1)
    loop_body = append_node
    if generator.ifs:
        condition = self.visit(generator.ifs[0])
        loop_body = IfStatNode(
            self.loc(node),
            if_clauses=[
                IfClauseNode(self.loc(node), condition=condition, body=append_node)
            ],
            else_clause=None,
        )
    return ForInStatNode(
        self.loc(node),
        target=self.visit(generator.target),
        iterator=iter_value,
        body=loop_body,
        else_clause=None,
        is_async=False,
    )
def visit_Assign(self, node):
    """Translate a single-target assignment.

    Raises:
        RuntimeError: the target is one of the compiled function's own
            parameter names.
    """
    # More than one target means a chained assignment (``a = b = v``).
    assert len(node.targets) == 1
    target = node.targets[0]
    is_param = hasattr(target, "id") and target.id in self.__func_params_name_list
    if is_param:
        raise RuntimeError("Can't assign to internal variables.")
    left = self.visit(node.targets[0])
    right = self.visit(node.value)
    return SingleAssignmentNode(self.loc(node), lhs=left, rhs=right)
def visit_AugAssign(self, node):
    """Translate an augmented assignment such as ``x += y``."""
    symbol = self.visit(node.op)
    target = self.visit(node.target)
    amount = self.visit(node.value)
    return InPlaceAssignmentNode(
        self.loc(node), operator=symbol, lhs=target, rhs=amount
    )
def visit_Raise(self, node):
    """Translate ``raise``, ``raise exc`` and ``raise exc from cause``."""
    if node.exc is None:
        # Bare ``raise`` re-raises the active exception; there is no
        # expression to visit.
        return ReraiseStatNode(self.loc(node))
    return RaiseStatNode(
        self.loc(node),
        exc_type=self.visit(node.exc),
        exc_value=None,
        exc_tb=None,
        cause=None if node.cause is None else self.visit(node.cause),
    )
def visit_ExceptHandler(self, node):
    """Translate one ``except [type [as name]]:`` clause.

    A bound name is only recorded when an exception type is given.
    """
    pattern = None
    target = None
    if node.type:
        pattern = [self.visit(node.type)]
        if node.name:
            target = NameNode(self.loc(node), name=node.name)
    handler_body = StatListNode(
        self.loc(node), stats=[self.visit(stat) for stat in node.body]
    )
    return ExceptClauseNode(
        self.loc(node),
        pattern=pattern,
        target=target,
        body=handler_body,
        is_except_as=False,
    )
def visit_Try(self, node):
    """Translate ``try`` with ``except``/``else`` and/or ``finally``.

    With handlers the body is wrapped in ``TryExceptStatNode``. A plain
    ``try/finally`` keeps the bare body so that no ``try`` without any
    ``except`` clause is produced. ``finally`` wraps whichever of the two
    applies.
    """
    body = StatListNode(
        self.loc(node), stats=[self.visit(stat) for stat in node.body]
    )
    except_clauses = [self.visit(ec) for ec in node.handlers]
    if node.orelse:
        else_clause = StatListNode(
            self.loc(node), stats=[self.visit(stat) for stat in node.orelse]
        )
    else:
        else_clause = None
    if except_clauses:
        guarded = TryExceptStatNode(
            self.loc(node),
            body=body,
            except_clauses=except_clauses,
            else_clause=else_clause,
        )
    else:
        # try/finally only: nothing to catch.
        guarded = body
    # with `finally` statement or not
    if node.finalbody:
        final_clause = StatListNode(
            self.loc(node), stats=[self.visit(stat) for stat in node.finalbody]
        )
        return TryFinallyStatNode(
            self.loc(node), body=guarded, finally_clause=final_clause
        )
    return guarded
def visit_Assert(self, node):
    """Translate ``assert test[, msg]``; the message is optional."""
    return AssertStatNode(
        self.loc(node),
        cond=self.visit(node.test),
        value=self.visit(node.msg) if node.msg else None,
    )
def visit_Delete(self, node):
    """Translate ``del a, b, ...`` by visiting each target."""
    return DelStatNode(
        self.loc(node), args=[self.visit(target) for target in node.targets]
    )
def visit_Pass(self, node):
    """Translate ``pass``."""
    return PassStatNode(self.loc(node))
def visit_If(self, node):
    """Translate an ``if`` statement with a single clause and optional ``else``.

    Whatever ``node.orelse`` holds is placed in the else clause as a
    statement list.
    """
    test = self.visit(node.test)
    then_block = StatListNode(
        self.loc(node), stats=[self.visit(stat) for stat in node.body]
    )
    otherwise = None
    if node.orelse:
        otherwise = StatListNode(
            self.loc(node), stats=[self.visit(stat) for stat in node.orelse]
        )
    clause = IfClauseNode(self.loc(node), condition=test, body=then_block)
    return IfStatNode(self.loc(node), if_clauses=[clause], else_clause=otherwise)
def visit_For(self, node):
    """Translate a synchronous ``for target in iter`` loop with optional ``else``."""
    loop_var = self.visit(node.target)
    sequence = self.visit(node.iter)
    iterator = IteratorNode(self.loc(node.iter), sequence=sequence)
    loop_body = StatListNode(
        self.loc(node), stats=[self.visit(stat) for stat in node.body]
    )
    otherwise = None
    if node.orelse:
        otherwise = StatListNode(
            self.loc(node), stats=[self.visit(stat) for stat in node.orelse]
        )
    return ForInStatNode(
        self.loc(node),
        target=loop_var,
        iterator=iterator,
        body=loop_body,
        else_clause=otherwise,
        is_async=False,
    )
def visit_While(self, node):
    """Translate a ``while`` loop with optional ``else``."""
    test = self.visit(node.test)
    loop_body = StatListNode(
        self.loc(node), stats=[self.visit(stat) for stat in node.body]
    )
    otherwise = None
    if node.orelse:
        otherwise = StatListNode(
            self.loc(node), stats=[self.visit(stat) for stat in node.orelse]
        )
    return WhileStatNode(
        self.loc(node), condition=test, body=loop_body, else_clause=otherwise
    )
def visit_withitem(self, node):
    """Translate a ``with`` item by visiting its context expression only."""
    return self.visit(node.context_expr)
def visit_With(self, node):
    """Translate ``with manager [as target]:`` for a single context manager."""
    # multiple items is not supported yet
    assert len(node.items) == 1
    item = node.items[0]
    manager = self.visit(item)
    # ``with manager:`` has no ``as`` target; there is nothing to visit then.
    if item.optional_vars is None:
        target = None
    else:
        target = self.visit(item.optional_vars)
    body = StatListNode(
        self.loc(node), stats=[self.visit(stat) for stat in node.body]
    )
    return WithStatNode(self.loc(node), manager=manager, target=target, body=body)
def visit_Break(self, node):
    """Translate ``break``."""
    return BreakStatNode(self.loc(node))
def visit_Continue(self, node):
    """Translate ``continue``."""
    return ContinueStatNode(self.loc(node))
def visit_FunctionDef(self, node):
    """Translate a ``@staticmethod`` UDF into a public ``cdef`` function.

    The typed signature depends on the program model and the function name:

    * PIE, for ``INIT`` / ``PEVAL`` / ``INCEVAL``:
      ``(Fragment &, Context[vd, md] &)``
    * Pregel ``INIT``: ``(Vertex[vd, md] &, Context[vd, md] &)``
    * Pregel ``COMPUTE``:
      ``(MessageIterator[md], Vertex[vd, md] &, Context[vd, md] &)``
    * Pregel ``COMBINE``: ``(MessageIterator[md])``, returning ``md``

    Every function other than Pregel ``COMBINE`` returns ``void``.

    Raises:
        RuntimeError: the first decorator is not ``staticmethod``, the
            function name is not recognized for the program model, or the
            program model itself is unknown.
        AssertionError: the parameter count does not match.
    """
    decorators = node.decorator_list
    if not (
        decorators
        and isinstance(decorators[0], ast.Name)
        and decorators[0].id == "staticmethod"
    ):
        raise RuntimeError("Missing decorator staticmethod.")

    function_name = node.name
    function_return_type = "void"
    params = node.args.args
    vd_md_types = [self._vd_type, self._md_type]

    def expect_params(count):
        # Kept as an assert so that callers still see AssertionError.
        assert len(params) == count, "The number of parameters does not match"

    def template_param(index, type_name, templates, use_ref=True):
        return self.make_template_arg(
            type_name,
            templates,
            params[index].arg,
            self.loc(params[index]),
            use_ref=use_ref,
        )

    if self._program_model == ProgramModel.PIE:  # PIE program model
        if function_name not in (
            ExpectFuncDef.INIT.value,
            ExpectFuncDef.PEVAL.value,
            ExpectFuncDef.INCEVAL.value,
        ):
            raise RuntimeError(
                "Not recognized method named {}".format(function_name)
            )
        expect_params(2)
        args = [
            self.make_ref_arg("Fragment", params[0].arg, self.loc(params[0])),
            template_param(1, "Context", vd_md_types),
        ]
    elif self._program_model == ProgramModel.Pregel:
        if function_name == ExpectFuncDef.INIT.value:
            expect_params(2)
            args = [
                template_param(0, "Vertex", vd_md_types),
                template_param(1, "Context", vd_md_types),
            ]
        elif function_name == ExpectFuncDef.COMPUTE.value:
            expect_params(3)
            args = [
                template_param(
                    0, "MessageIterator", [self._md_type], use_ref=False
                ),
                template_param(1, "Vertex", vd_md_types),
                template_param(2, "Context", vd_md_types),
            ]
        elif function_name == ExpectFuncDef.COMBINE.value:
            expect_params(1)
            args = [
                template_param(
                    0, "MessageIterator", [self._md_type], use_ref=False
                )
            ]
            function_return_type = self._md_type
        else:
            raise RuntimeError(
                "Not recognized method named {}".format(function_name)
            )
    else:
        # Without this branch `args` would be unbound further down.
        raise RuntimeError(
            "Not recognized program model {}".format(self._program_model)
        )

    base_type = CSimpleBaseTypeNode(
        self.loc(node),
        name=function_return_type,
        is_basic_c_type=1,
        signed=1,
        longness=0,
        is_self_arg=False,
    )
    declarator = CFuncDeclaratorNode(
        self.loc(node),
        base=CNameDeclaratorNode(self.loc(node), name=function_name),
        args=args,
        has_varargs=False,
        exception_value=None,
        exception_check=False,
        nogil=True,
        with_gil=False,
        overridable=False,
    )
    # traverse body
    body = StatListNode(
        self.loc(node), stats=[self.visit(expr) for expr in node.body]
    )
    return CFuncDefNode(
        self.loc(node),
        visibility="public",
        base_type=base_type,
        declarator=declarator,
        body=body,
        modifiers=[],
        api=False,
        overridable=False,
        is_const_method=False,
    )
def visit_Lambda(self, node):
    """Translate a lambda taking plain positional parameters.

    Only ``node.args.args`` is translated; ``*args`` / ``**kwargs`` are
    passed to Cython as absent.
    """
    return LambdaNode(
        self.loc(node),
        args=[
            self.make_plain_arg(arg.arg, self.loc(arg)) for arg in node.args.args
        ],
        star_arg=None,
        starstar_arg=None,
        # Was misspelled `retult_expr`, so the body never reached the
        # node's `result_expr` attribute.
        result_expr=self.visit(node.body),
    )
def visit_Return(self, node):
    """Translate ``return`` with or without a value."""
    returned = None if node.value is None else self.visit(node.value)
    return ReturnStatNode(self.loc(node), value=returned)
def visit_Yield(self, node):
    """Translate ``yield value`` into a ``YieldExprNode``."""
    # NOTE(review): Cython's own parser appears to pass the operand as
    # `arg`, not `expr` - confirm which attribute the patched writer reads.
    return YieldExprNode(self.loc(node), expr=self.visit(node.value))
def visit_YieldFrom(self, node):
    """Translate ``yield from value`` into a ``YieldFromExprNode``."""
    # NOTE(review): Cython's own parser appears to pass the operand as
    # `arg`, not `expr` - confirm which attribute the patched writer reads.
    return YieldFromExprNode(self.loc(node), expr=self.visit(node.value))
def visit_Global(self, node):
    """Translate ``global a, b, ...``, forwarding the list of names."""
    return GlobalNode(self.loc(node), names=node.names)
def visit_Nonlocal(self, node):
    """Translate ``nonlocal a, b, ...``, forwarding the list of names."""
    return NonlocalNode(self.loc(node), names=node.names)
def visit_Await(self, node):
    """Translate ``await value`` into an ``AwaitExprNode``."""
    # NOTE(review): Cython's own parser appears to pass the operand as
    # `arg`, not `expr` - confirm which attribute the patched writer reads.
    return AwaitExprNode(self.loc(node), expr=self.visit(node.value))