in nestedtensor/nested/fuser.py [0:0]
def fuse_conv_relu(model: torch.nn.Module, inplace: bool = False) -> torch.nn.Module:
    """
    Fuse Conv2d -> ReLU and Conv2d -> add -> ReLU chains for inference.

    The model is symbolically traced with ``torch.fx`` and a deep copy of the
    traced graph is rewritten in two passes:

    1. Conv2d -> ReLU: for every node accepted by ``matches_module_pattern``
       for the ``(Conv2d, ReLU)`` pattern, the conv module is replaced by a
       ``Conv2dReLU`` built from the conv's weight, bias, stride, padding,
       dilation and groups, and the ReLU node is removed. Skipped when the
       conv output has more than one user.
    2. Conv2d -> add -> ReLU: for three consecutive call nodes where the add
       consumes the conv output as its first operand and the ReLU consumes
       the add, the ReLU module is replaced by a ``Conv2dAddReLU`` that is
       called with ``(conv_input, add_second_operand)``; the conv and add
       nodes are removed.

    Args:
        model: Module to trace and rewrite.
        inplace: When False (the default) ``model`` is deep-copied before
            tracing. When True the copy is skipped, so the passed-in model
            may be modified.

    Returns:
        An ``fx.GraphModule`` built from the traced model and the rewritten
        graph.
    """
    import operator  # local import: only needed to recognise ``a + b`` nodes

    patterns = [(torch.nn.Conv2d, torch.nn.ReLU)]
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    # Pass 1: Conv2d -> ReLU  ==>  Conv2dReLU
    for pattern in patterns:
        for node in new_graph.nodes:
            if not matches_module_pattern(pattern, node, modules):
                continue
            conv_node = node.args[0]
            if len(conv_node.users) > 1:  # Output of conv is used by other nodes
                continue
            conv = modules[conv_node.target]
            fused_conv = Conv2dReLU(
                conv.weight, conv.bias, conv.stride,
                conv.padding, conv.dilation, conv.groups,
            )
            # The conv node now runs the fused module, so the ReLU node is
            # redundant: route its users to the conv node and drop it.
            replace_node_module(conv_node, modules, fused_conv)
            node.replace_all_uses_with(conv_node)
            new_graph.erase_node(node)

    # Pass 2: Conv2d -> add -> ReLU  ==>  Conv2dAddReLU
    add_targets = (operator.add, torch.add)
    # NOTE(review): the cap of 31 fusions is hard-coded in the original code;
    # its rationale is not visible here.
    max_add_relu_fusions = 31
    fused_count = 0
    window = []  # the three most recent call_function / call_module nodes
    for node in new_graph.nodes:
        if fused_count == max_add_relu_fusions:
            break
        if node.op not in ("call_function", "call_module"):
            continue
        window = (window + [node])[-3:]
        if len(window) < 3:
            continue
        conv_node, add_node, relu_node = window

        is_match = (
            conv_node.op == "call_module"
            and add_node.op == "call_function"
            and relu_node.op == "call_module"
            and isinstance(modules[conv_node.target], torch.nn.Conv2d)
            # Match on the called function rather than the generated node
            # name, which any function named ``add_*`` would also satisfy.
            and add_node.target in add_targets
            and isinstance(modules[relu_node.target], torch.nn.ReLU)
        )
        # Being adjacent in graph order does not imply being connected:
        # require conv -> add (first operand) -> relu, with no extra
        # arguments that the fused call would silently drop.
        is_match = (
            is_match
            and len(conv_node.args) == 1
            and len(add_node.args) == 2
            and not add_node.kwargs
            and add_node.args[0] is conv_node
            and add_node.args[1] is not conv_node
            and len(relu_node.args) == 1
            and relu_node.args[0] is add_node
        )
        # The conv and add nodes are erased below; erasing a node that still
        # has other users raises, so only fuse when their outputs are private
        # to this chain.
        is_match = (
            is_match
            and len(conv_node.users) == 1
            and len(add_node.users) == 1
        )
        if not is_match:
            continue

        conv = modules[conv_node.target]
        fused_conv = Conv2dAddReLU(
            conv.weight, conv.bias, conv.stride,
            conv.padding, conv.dilation, conv.groups,
        )
        replace_node_module(relu_node, modules, fused_conv)
        # The fused module takes the conv's input and the add's other operand.
        relu_node.args = (conv_node.args[0], add_node.args[1])
        # Erase consumers before producers so neither has users when erased.
        new_graph.erase_node(add_node)
        new_graph.erase_node(conv_node)
        fused_count += 1

    return fx.GraphModule(fx_model, new_graph)