def nnc_compile()

in fx/nnc_compile.py [0:0]


def nnc_compile(model: torch.nn.Module, example_inputs) -> torch.nn.Module:
    """
    Trace `model` with FX and compile the resulting graph with NNC.

    The model is symbolically traced, `ShapeProp` is run on `example_inputs`,
    and every graph node is lowered to an NNC tensor expression. The lowered
    computation is then compiled with the 'llvm' codegen backend.

    Args:
        model: The module to trace and compile.
        example_inputs: Positional example inputs, unpacked into
            `ShapeProp.propagate`. The compiled kernel is specialized to the
            shapes/dtypes recorded on the graph nodes for these inputs.

    Returns:
        A plain Python function `f(inps)` (not an `nn.Module`, despite the
        annotation). `inps` is a single sequence holding the input tensors
        followed by the tensors the outputs are written into; module
        attributes referenced by the graph are fetched from the traced module
        and prepended automatically. `f` itself returns `None`.

    Raises:
        RuntimeError: If a node has an unsupported dtype ("nyi"), an
            unsupported op kind ("not yet implemented"), or (when `f` is
            called) a `get_attr` target that does not exist on the traced
            module.
    """
    fx_model = fx.symbolic_trace(model)
    ShapeProp(fx_model).propagate(*example_inputs)

    # This env maps from node names to the NNC object (`te.Placeholder` or the
    # result of `lower_function`) that represents that node's value.
    env = {}

    def get_te_shapes(node):
        # NOTE(review): relies on `node.shape` having been recorded by the
        # ShapeProp pass above -- confirm against the ShapeProp version in use.
        return [te.ExprHandle.int(i) for i in node.shape]

    def get_nnc_type(dtype):
        if dtype == torch.float:
            return te.Dtype.Float
        elif dtype == torch.long:
            return te.Dtype.Long
        else:
            raise RuntimeError("nyi")

    def get_te_type(node):
        return get_nnc_type(node.dtype)

    def lookup_env(args):
        # Replace every fx.Node inside a (possibly nested) argument structure
        # with its entry in `env`; non-node values pass through untouched.
        return fx.node.map_aggregate(
            args, lambda x: env[x.name] if isinstance(x, fx.Node) else x
        )

    def fetch_attr(target: str):
        # Walk a dotted attribute path (e.g. "layer.weight") on the traced
        # module and return the object it names.
        target_atoms = target.split('.')
        attr_itr = fx_model
        for i, atom in enumerate(target_atoms):
            if not hasattr(attr_itr, atom):
                # Include the atom that failed to resolve so the message names
                # the actual missing path rather than its (possibly empty)
                # parent.
                raise RuntimeError(f"Node referenced nonexistant target {'.'.join(target_atoms[:i + 1])}")
            attr_itr = getattr(attr_itr, atom)
        return attr_itr

    outs = None
    inputs = []
    module_attrs = []
    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # We simply map the input placeholder to a `te.Placeholder`, which
            # also represents an input to the NNC computation.
            env[node.name] = te.Placeholder(node.name, get_te_type(node), get_te_shapes(node))
            inputs.append(env[node.name])
        elif node.op == 'call_function':
            # This does the bulk of the work - we call `lower_function`, which
            # returns a `te.ExprHandle` (the output of a NNC computation), and
            # put it in our environment.
            result = lower_function(node, node.target, lookup_env(node.args), node.args)
            env[node.name] = result
        elif node.op == 'output':
            outs = list(lookup_env(node.args))
        elif node.op == 'get_attr':
            # As NNC doesn't have any concept of state, we pull out the module
            # attributes and pass them in as inputs to NNC. The placeholder
            # must use this attribute's own shape, not the shape left over
            # from the last input placeholder.
            module_attrs.append(node)
            env[node.name] = te.Placeholder(node.name, get_te_type(node), get_te_shapes(node))
        else:
            raise RuntimeError("not yet implemented")

    loopnest = te.LoopNest(outs)
    loopnest.prepare_for_codegen()
    stmt = te.simplify(loopnest.root_stmt())

    # Buffer order of the compiled kernel: module attributes, then inputs,
    # then outputs. `f` below must pass its arguments in the same order.
    attr_buffers = [env[attr_node.name] for attr_node in module_attrs]
    buffer_args = [te.BufferArg(buf) for buf in attr_buffers + inputs + outs]
    cg = te.construct_codegen('llvm', stmt, buffer_args)

    def f(inps):
        # Module attributes are looked up at call time and prepended so the
        # caller only supplies inputs followed by output buffers.
        module_stuff = [fetch_attr(attr_node.target) for attr_node in module_attrs]
        cg.call(module_stuff + list(inps))
    return f