geneve/utils/ast_dag.py (119 lines of code) (raw):
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""AST drawing."""
import hashlib
import os
import random
import sys
from collections import namedtuple
from contextlib import contextmanager, nullcontext
import eql
import graphviz
Context = namedtuple("Context", ["graph", "colors"])
colors = ("red", "blue", "green", "orange", "darkorchid", "pink", "brown", "cyan", "purple")
random = random.Random()
class Digraph(graphviz.Digraph):
def _repr_mimebundle_(self, *args, **kwargs):
bundle = super()._repr_mimebundle_(*args, **kwargs)
bundle["text/plain"] = None
return bundle
def next_color(stack):
for c in colors:
if c not in stack:
return c
return stack[0]
@contextmanager
def new_color(ctx, attr):
new_color = next_color(ctx.colors)
ctx.colors.append(new_color)
ctx.graph.attr(attr, color=new_color, style="solid")
try:
yield
finally:
ctx.graph.attr(attr, color="black", style="dashed")
def get_node_id(label):
label = f"{label}-{random.random()}"
return hashlib.md5(label.encode("utf-8")).hexdigest()
def visit_ast(node, ctx, negate=False):
node_id = get_node_id(node.render())
if isinstance(node, eql.ast.Literal):
ctx.graph.node(node_id, node.render())
elif type(node) is eql.ast.Field:
ctx.graph.node(node_id, node.render())
elif type(node) is eql.ast.Or:
ctx.graph.node(node_id, "or")
for term in node.terms:
ctx.graph.attr("edge", color="black", style="solid")
with nullcontext() if negate else new_color(ctx, "edge"):
term_id = visit_ast(term, ctx, negate)
ctx.graph.edge(node_id, term_id)
elif type(node) is eql.ast.And:
ctx.graph.node(node_id, "and")
for term in node.terms:
ctx.graph.attr("edge", color="black", style="solid")
with new_color(ctx, "edge") if negate else nullcontext():
term_id = visit_ast(term, ctx, negate)
ctx.graph.edge(node_id, term_id)
elif type(node) is eql.ast.Not:
ctx.graph.node(node_id, "not")
term_id = visit_ast(node.term, ctx, not negate)
ctx.graph.edge(node_id, term_id)
elif type(node) is eql.ast.IsNull:
null_id = get_node_id("null")
ctx.graph.node(node_id, "==")
expr_id = visit_ast(node.expr, ctx, negate)
ctx.graph.node(null_id, "null")
ctx.graph.edge(node_id, expr_id)
ctx.graph.edge(node_id, null_id)
elif type(node) is eql.ast.IsNotNull:
null_id = get_node_id("null")
ctx.graph.node(node_id, "!=")
expr_id = visit_ast(node.expr, ctx, negate)
ctx.graph.node(null_id, "null")
ctx.graph.edge(node_id, expr_id)
ctx.graph.edge(node_id, null_id)
elif type(node) is eql.ast.InSet:
ctx.graph.node(node_id, "in")
ctx.graph.attr("edge", color="black", style="solid")
expr_id = visit_ast(node.expression, ctx, negate)
ctx.graph.edge(node_id, expr_id)
for term in node.container:
with nullcontext() if negate else new_color(ctx, "edge"):
term_id = visit_ast(term, ctx, negate)
ctx.graph.edge(node_id, term_id)
elif type(node) is eql.ast.Comparison:
ctx.graph.node(node_id, node.comparator)
left_id = visit_ast(node.left, ctx, negate)
right_id = visit_ast(node.right, ctx, negate)
ctx.graph.edge(node_id, left_id)
ctx.graph.edge(node_id, right_id)
elif type(node) is eql.ast.EventQuery:
visit_ast(node.query, ctx, negate)
elif type(node) is eql.ast.PipedQuery:
visit_ast(node.first, ctx, negate)
elif type(node) is eql.ast.FunctionCall:
ctx.graph.node(node_id, node.name.lower())
ctx.graph.attr("edge", color="black", style="solid")
arg_id = visit_ast(node.arguments[0], ctx, negate)
ctx.graph.edge(node_id, arg_id)
for arg in node.arguments[1:]:
with nullcontext() if negate else new_color(ctx, "edge"):
arg_id = visit_ast(arg, ctx, negate)
ctx.graph.edge(node_id, arg_id)
else:
raise ValueError(f"Unable to draw node type: {type(node)}")
return node_id
def draw_ast(ast, graph=None):
random.seed(ast.render())
if not graph:
graph = Digraph(format="svg")
visit_ast(ast, Context(graph, ["black"]))
return graph
def draw_query(query, filename):
ast = eql.parse_query(query)
name, ext = os.path.splitext(filename)
graph = graphviz.Digraph(comment=query, filename=name, format=ext[1:])
draw_ast(ast, graph)
graph.render(name)
if __name__ == "__main__":
if len(sys.argv) < 3:
sys.stderr.write(f"usage: {sys.argv[0]} <query> <filename>\n")
sys.exit(1)
draw_query(sys.argv[1], sys.argv[2])