in fx/nnc_compile.py [0:0]
def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    nnc_compile(model, example_inputs) returns a function with the same args
    as `model.forward`, with an extra argument corresponding to where the
    output is stored. This function takes the inputs (which must be PyTorch
    tensors with the same shapes as example_inputs), and passes them to an
    NNC executor.
    """
    fx_model = fx.symbolic_trace(model)
    # Shape propagation annotates every fx.Node with .shape/.dtype, which the
    # helpers below rely on when constructing NNC placeholders.
    ShapeProp(fx_model).propagate(*example_inputs)
    # This env maps from nodes to `te.ExprHandle`, which represent the output
    # of an NNC computation.
    env = {}

    def get_te_shapes(node):
        # Convert a node's (int) shape into NNC integer expression handles.
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        # Map a torch dtype onto the corresponding NNC dtype. Only the dtypes
        # used so far are supported.
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def lookup_env(l):
        # Replace every fx.Node inside the (possibly nested) structure `l`
        # with its NNC value from `env`; non-Node leaves pass through as-is.
        return fx.node.map_aggregate(l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def fetch_attr(target : str):
        # Resolve a dotted attribute path (e.g. "sub.weight") on the traced
        # module, raising a descriptive error on the first missing segment.
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            env[node.name] = te.Placeholder(node.name, get_te_type(node), get_te_shapes(node))
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args), node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC.
            module_attrs.append(node)
            # BUGFIX: use this node's own shapes rather than the stale `shapes`
            # local left over from the last 'placeholder' iteration (which was
            # wrong, and a NameError when a get_attr preceded any placeholder).
            env[node.name] = te.Placeholder(node.name, get_te_type(node), get_te_shapes(node))
        else:
            raise RuntimeError("not yet implemented")
    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())
    # Codegen argument order: module attributes first, then user inputs, then
    # outputs — the returned callable must be invoked with inputs + outputs in
    # that same order (attributes are fetched automatically).
    cg = te.construct_codegen('llvm', stmt, [te.BufferArg(x) for x in [env[i.name] for i in module_attrs] + inputs + outs])

    def f(inps):
        # Gather the current values of the module attributes, then run the
        # compiled kernel; outputs are written in place into `inps`' tail.
        module_stuff = [fetch_attr(i.target) for i in module_attrs]
        cg.call(module_stuff + list(inps))
    return f