in functorch/_src/aot_autograd.py
def partition_with_recompute_fwd_in_bwd(joint_module: fx.GraphModule, _joint_inputs):
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps trade memory bandwidth for computation.
    To create the fwd and bwd graphs, we copy the joint graph, manually set
    the outputs to just the original forward or backward outputs, and then
    run the resulting graphs through dead code elimination.
    """
    try:
        import networkx as nx
    except ImportError:
        raise RuntimeError("Need networkx installed to perform smart recomputation heuristics")
    # draw_graph(joint_module, "joint.svg")
    full_bw_graph = joint_module.graph
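    # A sketch of the idea, inferred from the construction below: we pick
    # which tensors to save by solving a min-cut problem. Forward inputs and
    # unrecomputable ops are wired to a "source" node, backward-only nodes to
    # a "sink" node, and each intermediate is split into an _in/_out pair
    # whose connecting edge is weighted by the tensor's size. The min cut is
    # then the cheapest set of tensors to save for the backward pass.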
    nx_graph = nx.DiGraph()
    tangent_closure = set()
    name_to_node = {}
    for node in full_bw_graph.nodes:
        name_to_node[node.name] = node
        if node.op == 'placeholder' and "tangents" in node.target:
            tangent_closure.add(node)
        if node in tangent_closure:
            for user in node.users:
                tangent_closure.add(user)
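    # Since fx graph nodes are topologically ordered, a single pass suffices:
    # tangent_closure now holds every node reachable from the tangent inputs,
    # i.e. the nodes that can only run in the backward pass.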
    pointwise_ops = [aten.add, aten.sub, aten.div, aten.atan2, aten.mul, aten.max, aten.min, aten.pow, aten.remainder, aten.fmod, aten.__and__, aten.__or__, aten.__xor__, aten.__lshift__, aten.__rshift__, aten.eq, aten.ne, aten.ge, aten.gt, aten.le, aten.lt, aten.abs, aten.bitwise_not, aten.ceil, aten.floor, aten.frac, aten.neg, aten.relu, aten.round, aten.silu, aten.trunc, aten.log, aten.log10, aten.log1p, aten.log2, aten.lgamma, aten.exp, aten.expm1, aten.erf, aten.erfc, aten.cos, aten.acos, aten.cosh, aten.sin, aten.asin, aten.sinh, aten.tan, aten.atan, aten.tanh, aten.atanh, aten.sqrt, aten.rsqrt, aten.reciprocal, aten.sigmoid, aten.softplus, aten.threshold, aten.threshold_backward, aten.clamp, aten.where, aten.lerp, aten.addcmul, aten.gelu, aten.gelu_backward]  # noqa: E501
    reduction_ops = [aten.softmax, aten._softmax, aten._softmax_backward_data, aten.sum, aten.mean, aten._grad_sum_to_size, aten.sum_to_size, aten.amax]  # noqa: E501
    misc_ops = [aten.to, aten.type_as, operator.getitem]
    # Not recomputed by default since these are kinda expensive/hard to fuse
    # norm_ops = [aten.instance_norm, aten._batch_norm_impl_index, aten.native_batch_norm, aten.batch_norm, aten._batch_norm_impl_index_backward, aten.native_layer_norm, aten.layer_norm, aten.native_layer_norm_backward]  # noqa: E501
    # Not used by default since NVFuser can't fuse view ops
    # view_ops = [aten.expand, aten.clone, aten.transpose, aten.t, aten.view, aten._unsafe_view, aten.permute, aten.transpose, aten.t, aten._reshape_alias, aten.squeeze, aten.unsqueeze, aten.reshape, aten.cat, aten.slice, aten.split, aten.select, aten.repeat]  # noqa: E501
    unrecomputable_ops = [aten.mm, aten.convolution, aten.convolution_backward, aten.bmm, aten.addmm, aten.native_dropout, aten.rand_like, aten.randn_like, aten.upsample_bilinear2d]  # noqa: E501
    recomputable_ops = set(
        pointwise_ops
        + reduction_ops
        + misc_ops
        # + norm_ops
        # + view_ops
    )
    # ops = set([i.target for i in joint_module.graph.nodes if i.op == 'call_function'])
    # print(ops - recomputable_ops)
    AGGRESSIVE_RECOMPUTATION = False
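    # When AGGRESSIVE_RECOMPUTATION is True, only the explicitly blacklisted
    # unrecomputable_ops are pinned to the forward pass; when False (the
    # default), anything outside the recomputable_ops whitelist is pinned.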
    for node in full_bw_graph.nodes:
        if node in tangent_closure:
            nx_graph.add_edge(node.name+"_in", "sink", capacity=math.inf)
            continue
        is_input = False
        if node.op == 'placeholder' and "primals" in node.target:
            nx_graph.add_edge("source", node.name+"_in", capacity=math.inf)
            is_input = True

        if AGGRESSIVE_RECOMPUTATION:
            if node.op == 'call_function' and node.target in unrecomputable_ops:
                nx_graph.add_edge("source", node.name+"_in", capacity=math.inf)
        else:
            if node.op == 'call_function' and node.target not in recomputable_ops:
                nx_graph.add_edge("source", node.name+"_in", capacity=math.inf)

        if 'tensor_meta' not in node.meta:
            weight = math.inf
        else:
            mem_sz = size_of(node.meta['tensor_meta'])
            if is_input:
                weight = mem_sz
            else:
                weight = mem_sz * 2

        nx_graph.add_edge(node.name+"_in", node.name+"_out", capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name+"_out", user.name+"_in", capacity=math.inf)
    cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
    reachable, non_reachable = partition
    cutset = set()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    cut_nodes = set()
    for node_in, node_out in cutset:
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)
    # print(len(cut_nodes), sorted(list(cut_nodes)))

    saved_values = [name_to_node[node] for node in cut_nodes]
    return _extract_fwd_bwd_modules(joint_module, saved_values)
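
# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hedged example of plugging this partitioner into AOT Autograd,
# assuming aot_function (defined elsewhere in this module) accepts fw_compiler
# and partition_fn callbacks of the form (fx.GraphModule, example_inputs):
#
#     import torch
#     from functorch.compile import aot_function
#
#     def print_compile(fx_g, _example_inputs):
#         fx_g.graph.print_tabular()  # inspect the partitioned fwd/bwd graphs
#         return fx_g                 # identity "compiler": run the graph as-is
#
#     def f(x):
#         return x.cos().cos()
#
#     aot_f = aot_function(f, fw_compiler=print_compile,
#                          partition_fn=partition_with_recompute_fwd_in_bwd)
#     aot_f(torch.randn(4, requires_grad=True)).sum().backward()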