# fx/vmap.py
from typing import Any, Optional, Tuple

import torch
import torch.fx as fx
from torch.fx.passes.shape_prop import ShapeProp


def vmap(model: torch.nn.Module, in_axes: Tuple[Optional[int], ...], example_args: Tuple[Any, ...]) -> torch.nn.Module:
"""vmap
Given a model with inputs, vmap will return a function that works on
batched versions of those inputs. Which inputs will be batched is
determined by in_axes. In addition, as vmap requires shape (actually
rank) information, we will pass in example_args (example inputs for the
original module).
"""
    in_axes = iter(in_axes)
    fx_model = fx.symbolic_trace(model)
    # Run a shape propagation pass to annotate the graph with shape information.
    ShapeProp(fx_model).propagate(*example_args)
    # As vmap rewrites the whole graph, it's easiest to create an entirely new
    # graph and append to that.
    new_graph: fx.Graph = fx.Graph()
    # `env` maps the name of each node in the old graph to the corresponding
    # node created in the new graph.
    env = {}

    def lookup_env(l):
        return fx.node.map_aggregate(l, lambda x: env[x.name] if isinstance(x, fx.Node) else x)

    for node in fx_model.graph.nodes:
        if node.op == 'placeholder':
            # If the node is an input placeholder, we simply copy it over and
            # annotate it with the batch dimension from `in_axes`.
            new_node = new_graph.placeholder(node.name)
            new_node.bdim = next(in_axes)
            new_node.shape = node.shape
            env[node.name] = new_node
        elif node.op == 'output':
            new_graph.output(env[node.args[0].name])
        elif node.op == 'call_function':
            new_args = lookup_env(node.args)
            # If any of the inputs to the function has a batch dimension, we
            # need to use our batching rules (`gen_batching_rule_function` is
            # defined elsewhere in this file). Otherwise, we simply copy the
            # node over.
            if any(x.bdim is not None for x in new_args if isinstance(x, fx.Node)):
                new_node = gen_batching_rule_function(node.target, *new_args)
            else:
                new_node = new_graph.node_copy(node, lambda x: env[x.name])
                new_node.bdim = None
                new_node.shape = node.shape
            env[node.name] = new_node
        else:
            raise RuntimeError(f"vmap does not yet support nodes with op {node.op!r}")
    res = fx.GraphModule(fx_model, new_graph)
    print(res.code)
    res.graph.lint()
    return res
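

# A minimal usage sketch, not part of the original file: it assumes a model
# whose traced graph contains only `call_function` ops that
# `gen_batching_rule_function` (defined elsewhere in this file) knows how to
# batch. The module, shapes, and `in_axes` below are illustrative assumptions.
if __name__ == "__main__":
    class MulModel(torch.nn.Module):
        def forward(self, x, y):
            return torch.mul(x, y)

    x = torch.randn(3)
    y = torch.randn(3)
    # Batch the first input along dim 0; leave the second input unbatched.
    batched_model = vmap(MulModel(), in_axes=(0, None), example_args=(x, y))
    xs = torch.randn(5, 3)
    out = batched_model(xs, y)
    print(out.shape)  # expected: torch.Size([5, 3])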