in python/tvm/script/parser/core/evaluator.py [0:0]
def _visit(self, node: doc.AST) -> Any:
    """General doc AST node visiting method for expression evaluation.

    Sub-expressions are visited recursively (bottom-up). Each compound
    expression node is rebuilt from its visited fields and evaluated. The
    value is passed to ``self._add_intermediate_result``, and whatever that
    call produces is returned.

    Before recursing, a pre-pass runs when ``node`` is one of:

    * an attribute call (``x.f(...)``) whose attribute is not ``reads``,
      ``writes``, ``match_buffer`` or ``realize``;
    * a ``BinOp``, ``UnaryOp``, ``Compare`` or ``BoolOp``.

    In that case, every step-less slice ``lower:upper`` (both bounds given)
    that appears in a direct subscript operand receives an explicit constant
    step of ``1``. This mutates ``node`` in place.

    Parameters
    ----------
    node : doc.AST
        The root node of AST tree node of expression to evaluate. Lists and
        tuples of nodes are also accepted and are visited element-wise.

    Returns
    -------
    res : Any
        The evaluation result.

        * ``Name``, ``Constant``, context/operator nodes and non-expression
          nodes are returned unchanged.
        * Lists and tuples come back as a list/tuple of visited elements.

    Raises
    ------
    ParserError
        If a ``Name`` is neither in ``self.value_table`` nor a builtin.
        Exceptions raised while evaluating a node are routed through
        ``self.parser.report_error``. If that call returns instead of
        raising, the original exception is re-raised.
    """
    # Collect the direct operands of ``node`` whose slices may need a step.
    args = []
    if (
        isinstance(node, doc.Call)
        and hasattr(node.func, "attr")
        and node.func.attr not in ("reads", "writes", "match_buffer", "realize")
    ) or isinstance(node, (doc.BinOp, doc.UnaryOp, doc.Compare, doc.BoolOp)):
        if isinstance(node, doc.BinOp):
            args = [node.left, node.right]
        elif isinstance(node, doc.UnaryOp):
            args = [node.operand]
        elif isinstance(node, doc.Compare):
            args = [node.left, *node.comparators]
        elif isinstance(node, doc.Call):
            args = node.args
        elif isinstance(node, doc.BoolOp):
            args = node.values
    for arg in args:
        if not isinstance(arg, doc.Subscript):
            continue
        if isinstance(arg.slice, doc.Slice):
            check_slices = [arg.slice]
        elif isinstance(arg.slice, doc.Tuple):
            # Multi-dimensional subscript: only its slice elements are patched.
            check_slices = [p for p in arg.slice.elts if isinstance(p, doc.Slice)]
        else:
            continue
        for s in check_slices:
            # Only fully bounded slices without an explicit step are patched.
            if not s.step and s.upper and s.lower:
                # A constant 1 whose source span is placed just after ``upper``.
                s.step = doc.Constant(
                    1,
                    None,
                    1,
                    1,
                    s.upper.lineno,
                    s.upper.end_col_offset + 1,
                    s.upper.lineno,
                    s.upper.end_col_offset + 2,
                )

    # Containers of nodes are visited element-wise, preserving their type.
    if isinstance(node, list):
        return [self._visit(n) for n in node]
    if isinstance(node, tuple):
        return tuple(self._visit(n) for n in node)
    assert isinstance(node, doc.AST)
    if isinstance(node, doc.Name):
        if node.id not in self.value_table and not _get_builtin_or_none(node.id):
            raise ParserError(node, f"Undefined variable: {node.id}")
        return node
    # Leaf nodes are returned as-is, without evaluation.
    if isinstance(
        node,
        (
            doc.Constant,
            doc.expr_context,
            doc.operator,
            doc.boolop,
            doc.unaryop,
            doc.cmpop,
        ),
    ):
        return node
    if not isinstance(node, (doc.expr, doc.slice)):
        return node
    if isinstance(node, doc.Lambda):
        return self._eval_lambda(node)
    if isinstance(node, doc.Starred):
        # Only the starred value is visited; the Starred wrapper is rebuilt
        # around it rather than evaluated.
        value = self._visit(node.value)
        return doc.Starred(
            value=value,
            ctx=node.ctx,
            lineno=node.lineno,
            col_offset=node.col_offset,
            end_lineno=node.end_lineno,
            end_col_offset=node.end_col_offset,
        )
    # Visit every AST-valued field; copy plain attributes through unchanged.
    fields = {}
    for field in node.__class__._FIELDS:  # pylint: disable=protected-access
        attr = getattr(node, field)
        if isinstance(attr, (doc.AST, tuple, list)):
            fields[field] = self._visit(attr)
        else:
            fields[field] = attr
    try:
        if isinstance(node, doc.BoolOp):
            value = self._eval_bool_op(fields)
        elif isinstance(node, doc.Compare):
            value = self._eval_compare(fields)
        elif isinstance(node, doc.UnaryOp):
            value = self._eval_unary_op(fields)
        elif isinstance(node, doc.BinOp):
            value = self._eval_bin_op(fields)
        elif isinstance(node, doc.Slice):
            value = self._eval_slice(fields)
        else:
            value = self._eval_expr(node.__class__(**fields))
    except Exception as err:  # pylint: disable=broad-except
        self.parser.report_error(node, err)
        # report_error is expected to raise. If it ever returns, re-raise the
        # original error rather than reading the unbound ``value`` below.
        raise
    return self._add_intermediate_result(value)