in lucid/misc/graph_analysis/overlay_graph.py [0:0]
def __init__(self, tf_graph, nodes=None, no_pass_through=None, prev_overlay=None):
    """Build an overlay over `tf_graph` restricted to the tensors named in `nodes`.

    Args:
      tf_graph: The underlying graph. This method itself only calls
        `get_operations()` on it (and reads each op's `outputs[*].name`),
        and only when `nodes` is None.
      nodes: Iterable of tensor names to include, or None to include every
        output tensor of every operation in `tf_graph`. Any iterable is
        accepted (it is materialized once); its order determines the order
        of `self.nodes`.
      no_pass_through: Stored on `self.no_pass_through` (defaults to an empty
        list). NOTE(review): presumably consulted by `_get_overlay_inputs` to
        stop tracing through those tensors -- confirm against that method.
      prev_overlay: Optional overlay this one is derived from; stored on
        `self.prev_overlay` before any inputs are computed.

    After construction:
      - `node_to_inputs[n]` / `node_to_consumers[n]` hold the direct overlay
        edges of node `n`, in both directions.
      - `node_to_extended_inputs[n]` / `node_to_extended_consumers[n]` hold
        every overlay node reachable from `n` by following those edges
        transitively, independent of the order of `self.nodes`.
    """
    self.tf_graph = tf_graph

    if nodes is None:
        nodes = [out.name
                 for op in tf_graph.get_operations()
                 for out in op.outputs]
    else:
        # Materialize once: `nodes` is traversed twice below, and a one-shot
        # iterator would otherwise leave `self.nodes` silently empty.
        nodes = list(nodes)

    self.name_map = {name: OverlayNode(name, self) for name in nodes}
    self.nodes = [self.name_map[name] for name in nodes]
    self.no_pass_through = [] if no_pass_through is None else no_pass_through
    self.node_to_consumers = defaultdict(set)
    self.node_to_inputs = defaultdict(set)
    # Assigned before the edge loop below, since `_get_overlay_inputs`
    # may rely on it.
    self.prev_overlay = prev_overlay

    # Direct edges between overlay nodes, recorded in both directions.
    for node in self.nodes:
        for inp in self._get_overlay_inputs(node):
            self.node_to_inputs[node].add(inp)
            self.node_to_consumers[inp].add(node)

    def transitive_closure(adjacency):
        """Map each overlay node to every node reachable through `adjacency`.

        Uses an iterative post-order DFS, so the result does not depend on
        `self.nodes` being topologically ordered, and deep graphs cannot hit
        the interpreter's recursion limit. If the edges contain a cycle, an
        edge back to a node still on the DFS stack contributes that node but
        not its (still unfinished) reachable set.
        """
        reach = defaultdict(set)
        finished = set()
        for root in self.nodes:
            if root in finished:
                continue
            on_stack = {root}
            stack = [(root, iter(adjacency[root]))]
            while stack:
                node, pending = stack[-1]
                for child in pending:
                    if child not in finished and child not in on_stack:
                        # Descend first; `node` is revisited once `child`
                        # has been finalized.
                        on_stack.add(child)
                        stack.append((child, iter(adjacency[child])))
                        break
                else:
                    # Every neighbour is finished (or on the stack): `node`'s
                    # reachable set can now be assembled from theirs.
                    stack.pop()
                    on_stack.discard(node)
                    finished.add(node)
                    node_reach = reach[node]
                    for child in adjacency[node]:
                        node_reach.add(child)
                        node_reach.update(reach[child])
        return reach

    self.node_to_extended_consumers = transitive_closure(self.node_to_consumers)
    self.node_to_extended_inputs = transitive_closure(self.node_to_inputs)