def get_nodes_and_edges()

in amazon_comprehend_events_tutorial/notebooks/events_graph.py [0:0]


def get_nodes_and_edges(
    result, node_types=('event', 'trigger', 'entity_group', 'entity'), thr=0.0
):
    """Convert results to (nodelist, edgelist) depending on specified entity types.

    Args:
        result: Mapping read through the keys 'Events' (each item providing
            'Triggers' and 'Arguments') and 'Entities' (each item providing
            'Mentions'). Only the keys needed for the requested node types
            are accessed.
        node_types: Container of node kinds to emit; any of 'event',
            'trigger', 'entity_group' and 'entity'. Only membership tests
            are performed on it.
        thr: Exclusive lower bound applied to scores. Events, entity groups
            and coreference edges are filtered on 'GroupScore'; trigger and
            entity mention nodes are filtered on 'Score'.

    Returns:
        A ``(nodes, edges)`` pair. Nodes are
        ``(id, type, tag, score, mention_type)`` tuples and edges are
        ``(source_id, target_id, role, score, edge_type)`` tuples. Every
        edge endpoint is the id of a node present in ``nodes``.
    """
    nodes = []
    edges = []
    event_nodes = []
    trigger_nodes = []
    entity_group_nodes = []
    entity_nodes = []

    # Nodes are (id, type, tag, score, mention_type) tuples.
    if 'event' in node_types:
        # One node per event, described by its first trigger only.
        event_nodes = [
            ("ev%d" % i, t['Type'], t['Type'], t['Score'], "event")
            for i, e in enumerate(result['Events'])
            for t in e['Triggers'][:1]
            if t['GroupScore'] > thr
        ]
        nodes.extend(event_nodes)

    if 'trigger' in node_types:
        trigger_nodes = [
            ("ev%d-tr%d" % (i, j), t['Type'], t['Text'], t['Score'], "trigger")
            for i, e in enumerate(result['Events'])
            for j, t in enumerate(e['Triggers'])
            if t['Score'] > thr
        ]
        # De-duplicate on (type, text) across the whole result; when several
        # triggers share a key, the last one seen is the one that is kept.
        trigger_nodes = list({t[1:3]: t for t in trigger_nodes}.values())
        nodes.extend(trigger_nodes)

    if 'entity_group' in node_types:
        # When individual mentions are drawn too, the group node is tagged
        # with its type; otherwise it carries the canonical mention's text.
        entity_group_nodes = [
            (
                "gr%d" % i,
                m['Type'],
                m['Text'] if 'entity' not in node_types else m['Type'],
                m['Score'],
                "entity_group",
            )
            for i, e in enumerate(result['Entities'])
            for m in get_canonical_mention(e['Mentions'])
            if m['GroupScore'] > thr
        ]
        nodes.extend(entity_group_nodes)

    if 'entity' in node_types:
        entity_nodes = [
            ("gr%d-en%d" % (i, j), m['Type'], m['Text'], m['Score'], "entity")
            for i, e in enumerate(result['Entities'])
            for j, m in enumerate(e['Mentions'])
            if m['Score'] > thr
        ]
        # Same (type, text) de-duplication as for triggers.
        entity_nodes = list({t[1:3]: t for t in entity_nodes}.values())
        nodes.extend(entity_nodes)

    # Ids of the nodes that survived filtering. Edges are only emitted
    # between ids in these sets so that no edge points at a missing node.
    event_keys = {n[0] for n in event_nodes}
    trigger_keys = {n[0] for n in trigger_nodes}
    group_keys = {n[0] for n in entity_group_nodes}
    entity_keys = {n[0] for n in entity_nodes}

    # Edges are (source_id, target_id, role, score, type) tuples.
    if event_nodes and entity_group_nodes:
        # Event -> entity group, labelled with the argument role. Argument
        # scores are deliberately not compared against thr.
        edges.extend(
            ("ev%d" % i, "gr%d" % a['EntityIndex'], a['Role'], a['Score'], "argument")
            for i, e in enumerate(result['Events'])
            if "ev%d" % i in event_keys
            for a in e['Arguments']
            if "gr%d" % a['EntityIndex'] in group_keys
        )

    if entity_nodes and entity_group_nodes:
        # Entity group -> each of its surviving mentions.
        edges.extend(
            ("gr%d" % i, "gr%d-en%d" % (i, j), "", m['GroupScore'], "coref")
            for i, e in enumerate(result['Entities'])
            if "gr%d" % i in group_keys
            for j, m in enumerate(e['Mentions'])
            if "gr%d-en%d" % (i, j) in entity_keys
            if m['GroupScore'] > thr
        )

    if event_nodes and trigger_nodes:
        # Event -> each of its surviving triggers.
        edges.extend(
            ("ev%d" % i, "ev%d-tr%d" % (i, j), "", t['GroupScore'], "coref")
            for i, e in enumerate(result['Events'])
            if "ev%d" % i in event_keys
            for j, t in enumerate(e['Triggers'])
            if "ev%d-tr%d" % (i, j) in trigger_keys
            if t['GroupScore'] > thr
        )

    return nodes, edges