torchbenchmark/util/classify_graphs.py (143 lines of code) (raw):
import argparse
import ast
import re
from enum import Enum

import torch
class OpType(Enum):
    """Coarse op categories used to "hash" TorchScript graphs by structure.

    Each graph node is bucketed into one of these categories (see ``op_types``
    and ``compress_graph``) so graphs that differ only in the exact pointwise /
    reduction ops they use compare equal after compression.
    """

    POINTWISE = 1
    NORMS = 2
    REDUCTIONS = 3
    VIEWS_EXPANDS = 4
    # REMOVE: node is destroyed during compression and its outputs are not
    # forwarded as inputs to placeholder nodes (e.g. prim::Constant).
    REMOVE = 5
    # IGNORE: node is skipped by compression and left in the graph untouched
    # (e.g. prim::TupleConstruct).
    IGNORE = 6
# Classification of each known aten/prim op into its coarse OpType bucket.
# Only membership tests and key lookups are performed on this table, so the
# entries are grouped by category for readability.
op_types = {
    # Pointwise / elementwise ops.
    "aten::rsqrt": OpType.POINTWISE,
    "aten::abs": OpType.POINTWISE,
    "aten::eq": OpType.POINTWISE,
    "aten::gelu": OpType.POINTWISE,
    "aten::remainder": OpType.POINTWISE,
    "aten::_softmax": OpType.POINTWISE,
    "aten::clamp": OpType.POINTWISE,
    "aten::gt": OpType.POINTWISE,
    "aten::mul": OpType.POINTWISE,
    "aten::add": OpType.POINTWISE,
    "aten::ne": OpType.POINTWISE,
    "aten::silu": OpType.POINTWISE,
    "aten::pow": OpType.POINTWISE,
    "aten::ge": OpType.POINTWISE,
    "aten::sub": OpType.POINTWISE,
    "aten::sqrt": OpType.POINTWISE,
    "aten::reciprocal": OpType.POINTWISE,
    "aten::relu": OpType.POINTWISE,
    "aten::div": OpType.POINTWISE,
    "aten::tanh": OpType.POINTWISE,
    "aten::neg": OpType.POINTWISE,
    "aten::log": OpType.POINTWISE,
    "aten::exp": OpType.POINTWISE,
    "aten::sigmoid": OpType.POINTWISE,
    # Reductions.
    "aten::sum": OpType.REDUCTIONS,
    "aten::mean": OpType.REDUCTIONS,
    # Normalization ops.
    "aten::native_batch_norm": OpType.NORMS,
    "aten::native_layer_norm": OpType.NORMS,
    # View / expand ops.
    "aten::reshape": OpType.VIEWS_EXPANDS,
    "aten::unsqueeze": OpType.VIEWS_EXPANDS,
    # Structural prim ops.
    "prim::Constant": OpType.REMOVE,
    "prim::TupleConstruct": OpType.IGNORE,
}
def type_to_placeholder(op_type: OpType) -> str:
    """Return the placeholder op name that stands in for *op_type*.

    Raises KeyError for an OpType without a placeholder (none currently).
    """
    placeholders = {
        OpType.POINTWISE: "aten::pointwise_placeholder",
        OpType.REDUCTIONS: "aten::reduction_placeholder",
        OpType.NORMS: "aten::norm_placeholder",
        OpType.VIEWS_EXPANDS: "aten::view_expand_placeholder",
        OpType.REMOVE: "aten::remove_placeholder",
        OpType.IGNORE: "aten::ignore_placeholder",
    }
    return placeholders[op_type]
# get the op type. op_name is expected to be the qualified name.
def get_type(op_name: str) -> OpType:
    """Map a qualified op name — or a placeholder name — to its OpType."""
    known = op_types.get(op_name)
    if known is not None:
        return known
    # Placeholder nodes produced by compress_graph map back to their category.
    for candidate in OpType:
        if op_name == type_to_placeholder(candidate):
            return candidate
    raise NotImplementedError(f"No OpType known for op '{op_name}'")
def simplify_tensor_type(jit_type):
    """Collapse any concrete TensorType to the generic, unspecialized one.

    Non-tensor JIT types are returned unchanged.
    """
    is_tensor = isinstance(jit_type, torch._C.TensorType)
    return torch._C.TensorType.get() if is_tensor else jit_type
def remove_inputs(graph):
    """Strip every input from *graph*, detaching all of their uses first.

    Returns the same graph object, mutated in place.
    """
    graph_inputs = list(graph.inputs())
    for value in graph_inputs:
        # Detach each consumer before the input value disappears.
        for use in value.uses():
            use.user.removeInput(use.offset)
    # Erase from the back so earlier indices stay valid.
    for index in range(len(graph_inputs) - 1, -1, -1):
        graph.eraseInput(index)
    return graph
# Remove vertices like x or y below, where x or y are pointwise.
# (pointwise) --> (x) --> (...)
# (...) --> (y) --> (pointwise)
# if remove_all is true, then it doesn't care if pointwise ops precede/succeed x or y.
def remove_duplicate_pointwise(graph, remove_all=False):
    """Collapse runs of single-input/single-output pointwise nodes in place.

    A pointwise node is bypassed (its output rerouted to its input) when its
    producer is pointwise, or when all of its non-Return consumers are
    pointwise. With ``remove_all=True`` every 1-in/1-out pointwise node that
    has a producer or at least one consumer is bypassed. Returns the graph.
    """
    to_remove = []

    def bypass_node(n):
        # Reroute n's single output to its single input; defer the actual
        # destroy until iteration over graph.nodes() is finished.
        to_remove.append(n)
        n.output().replaceAllUsesWith(n.input())

    for n in graph.nodes():
        if get_type(n.kind()) != OpType.POINTWISE:
            continue
        # Only 1-in/1-out nodes can be transparently bypassed.
        if n.inputsSize() != 1 or n.outputsSize() != 1:
            continue
        # Case 1: producer is pointwise (or remove_all) — drop this node.
        if get_type(n.input().node().kind()) == OpType.POINTWISE or remove_all:
            bypass_node(n)
            continue
        # Case 2: every real (non-Return) consumer is pointwise.
        uses = [r.user for r in n.output().uses() if r.user.kind() != "prim::Return"]
        if len(uses) >= 1 and (all(get_type(r.kind()) == OpType.POINTWISE for r in uses) or remove_all):
            bypass_node(n)
            continue
    # Destroy in reverse so users are removed before their producers.
    for n in reversed(to_remove):
        n.destroy()
    return graph
def compress_graph(graph):
    """Rewrite *graph* in place into its compressed ("hashed") form.

    Every node is replaced by a category placeholder node (pointwise, norm,
    reduction, view/expand) so that two graphs which differ only in the exact
    ops they use become textually identical. IGNORE nodes are left untouched,
    REMOVE nodes (e.g. prim::Constant) are destroyed and their outputs are not
    forwarded. Graph inputs and redundant pointwise chains are then stripped,
    and the canonicalized graph is returned.
    """
    old_nodes = []
    erased_nodes = set()
    for n in graph.nodes():
        simple_type = get_type(n.kind())
        if simple_type == OpType.IGNORE:
            # Keep the node as-is; it is not part of the compressed form.
            continue
        old_nodes.append(n)
        if simple_type == OpType.REMOVE:
            # Destroyed later; inputs that came from it are dropped below.
            erased_nodes.add(n)
            continue
        # Insert a placeholder with the same output arity before the original.
        new_node = graph.create(type_to_placeholder(simple_type), n.outputsSize())
        new_node.insertBefore(n)
        for inp in n.inputs():
            # Skip values produced by REMOVE-category nodes.
            if inp.node() not in erased_nodes:
                new_node.addInput(inp)
        for old_out, new_out in zip(n.outputs(), new_node.outputs()):
            # Erase concrete tensor shapes/dtypes so they don't affect hashing.
            new_out.setType(simplify_tensor_type(old_out.type()))
            old_out.replaceAllUsesWith(new_out)
    # Destroy in reverse order so consumers go before their producers.
    for n in reversed(old_nodes):
        n.destroy()
    graph = remove_inputs(graph)
    graph = remove_duplicate_pointwise(graph)
    return torch._C._jit_pass_canonicalize(graph, False)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="""
Collection of helper functions for eliminating duplicate subgraphs
Usage:
~~~
import classify_graphs
# some ir string called "ir"
graph = torch._C.parse_ir(ir)
# "hashes" the graph based on categories of ops (pointwise, reductions, views/expands, norms)
compressed_graph = classify_graphs.compress_graph(graph)
# do something with the compressed graph
~~~
Alternatively, call it and it will return one graph per hashed category
Usage:
python3 log_extract.py log.txt --output > log_result.py
python3 classify_graphs.py log_result.py > filtered_logs.py
""", formatter_class=argparse.RawDescriptionHelpFormatter)
    parser.add_argument("filename", type=str, help="output from log_extract.py --help")
    args = parser.parse_args()

    with open(args.filename) as f:
        # The input file is a printed Python list of IR strings; parse it as a
        # literal rather than eval() so code embedded in the file cannot run.
        arr = ast.literal_eval(f.read())

    # see 73984: graphs containing annotate(List[int]...) cannot be re-parsed,
    # so substitute the first graph for them.
    for i in range(len(arr)):
        if len(re.findall(r'value=annotate\(List\[int', arr[i])) >= 1:
            arr[i] = arr[0]

    # Bucket the raw IR strings by their compressed ("hashed") graph form.
    classified = {}
    for ir in arr:
        graph = torch._C.parse_ir(ir)
        graph = compress_graph(graph)
        graph_class = str(graph)
        classified.setdefault(graph_class, []).append(ir)

    # Keep one representative per bucket: the longest graph of that type.
    # (Fixes the original code, which sorted by length but then appended
    # graphs[0] instead of the sorted head.)
    final_selection = [max(graphs, key=len) for graphs in classified.values()]
    print('[' + ', '.join(f'"""{x}"""' for x in final_selection) + ']')