def vmap()

in fx/vmap.py [0:0]


def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...], example_args: Tuple[Any, ...]) -> torch.nn.Module:
    """Return a ``GraphModule`` that runs ``model`` on batched inputs.

    ``model`` is symbolically traced, and its graph is rewritten node by node
    into a fresh graph. Nodes that consume a batched value are lowered through
    ``gen_batching_rule_function``; every other node is copied unchanged.
    The batching rules need shape (actually rank) information, so
    ``example_args`` are first run through ``ShapeProp`` to annotate the
    traced graph.

    Args:
        model: Module to vectorize; must be traceable by ``fx.symbolic_trace``.
        in_axes: One entry per graph placeholder, in placeholder order, giving
            the batch dimension of that input, or ``None`` if it is not
            batched. Extra trailing entries are ignored.
        example_args: Example inputs for the original (unbatched) module,
            used only for shape propagation.

    Returns:
        An ``fx.GraphModule`` (rooted at the traced model) whose graph is the
        rewritten, batched graph.

    Raises:
        ValueError: If ``in_axes`` has fewer entries than the graph has
            placeholders.
        RuntimeError: For graph ops other than ``placeholder``,
            ``call_function`` and ``output``. Also raised for a
            ``call_function`` node with a batched input and keyword
            arguments, because the batching rules are only given positional
            arguments, so kwargs would otherwise be silently dropped.
    """
    import logging

    axes_iter = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Annotate the traced graph with shape information from the example inputs.
    ShapeProp(fx_model).propagate(*example_args)
    # vmap rewrites the whole graph, so it is easiest to build an entirely
    # new graph and append to it.
    new_graph: fx.Graph = fx.Graph()

    # Maps the *name* of each node in the traced graph to its counterpart
    # in new_graph.
    env = {}

    def lookup_env(a):
        # Replace every Node (at any nesting depth) with its new-graph counterpart.
        return fx.node.map_aggregate(a, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    def has_batched_input(a) -> bool:
        # Walk nested tuples/lists/dicts so that e.g. the list argument of
        # torch.cat([x, y]) is inspected, not just top-level positionals.
        found = False

        def visit(x):
            nonlocal found
            if isinstance(x, fx.Node) and x.bdim is not None:
                found = True
            return x

        fx.node.map_aggregate(a, visit)
        return found

    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # Copy input placeholders over and annotate them with their batch
            # dimension from `in_axes`.
            new_node = new_graph.placeholder(node.name)
            try:
                new_node.bdim = next(axes_iter)
            except StopIteration:
                raise ValueError(
                    f"in_axes has fewer entries than the traced model has inputs "
                    f"(no entry for placeholder {node.name!r})"
                ) from None
            new_node.shape = node.shape
            env[node.name] = new_node
        elif node.op == 'output':
            # lookup_env handles single-node, tuple/list and constant outputs.
            # NOTE(review): outputs keep whatever batch dim the batching rules
            # produced; nothing here moves it to a fixed position.
            new_graph.output(lookup_env(node.args[0]))
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            new_kwargs = lookup_env(node.kwargs)
            # If any input (positional or keyword, at any nesting depth) is
            # batched, the batching rules must be used; otherwise copy as-is.
            if has_batched_input((new_args, new_kwargs)):
                if node.kwargs:
                    raise RuntimeError(
                        f"Batching call_function {node.target} with keyword "
                        f"arguments is not yet implemented"
                    )
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
            new_node.shape = node.shape
            env[node.name] = new_node
        else:
            raise RuntimeError(f"Not yet implemented: node op {node.op!r}")

    res = fx.GraphModule(fx_model, new_graph)
    # Generated code is logged at debug level instead of printed unconditionally.
    logging.getLogger(__name__).debug("vmap generated code:\n%s", res.code)
    res.graph.lint()
    return res