def fuse_conv_relu()

in nestedtensor/nested/fuser.py [0:0]


def fuse_conv_relu(
    model: torch.nn.Module, inplace=False, *, max_add_relu_fusions=31
) -> torch.nn.Module:
    """
    Fuse ``Conv2d -> ReLU`` and ``Conv2d -> add -> ReLU`` chains for inference.

    The model is traced with ``torch.fx.symbolic_trace`` and rewritten in two
    passes:

    1. A ``Conv2d`` module call whose only consumer is a ``ReLU`` module call
       is replaced by a single ``Conv2dReLU`` module and the ReLU call is
       removed.
    2. A ``Conv2d`` module call that is the left operand of an ``add`` whose
       only consumer is a ``ReLU`` module call is replaced by a single
       ``Conv2dAddReLU`` call taking the conv input and the other add operand.

    Chains that cannot be fused safely are left unchanged. This covers
    intermediate results with other users, modules invoked from more than one
    call site, and non-``"zeros"`` padding modes.

    Args:
        model: Module to fuse; must be traceable by ``torch.fx.symbolic_trace``.
        inplace: If False (default), ``model`` is deep-copied before tracing,
            so the result does not share parameters with it. If True, no copy
            is made and the result may share submodules/parameters with
            ``model``.
        max_add_relu_fusions: Maximum number of pass-2 (conv/add/relu)
            fusions, or ``None`` for no limit. Defaults to 31 to keep this
            function's historical behaviour.
            NOTE(review): the origin of 31 is unknown; confirm whether it is
            still needed.

    Returns:
        A ``torch.fx.GraphModule`` executing the fused graph.
    """
    if not inplace:
        model = copy.deepcopy(model)
    fx_model = fx.symbolic_trace(model)
    modules = dict(fx_model.named_modules())
    new_graph = copy.deepcopy(fx_model.graph)

    _fuse_conv_relu_pairs(new_graph, modules)
    _fuse_conv_add_relu(new_graph, modules, max_add_relu_fusions)
    return fx.GraphModule(fx_model, new_graph)


def _count_call_module_targets(graph):
    """Map each ``call_module`` target in ``graph`` to how many nodes call it."""
    counts = {}
    for node in graph.nodes:
        if node.op == "call_module":
            counts[node.target] = counts.get(node.target, 0) + 1
    return counts


def _padding_mode_is_fusable(conv):
    """Return True if ``conv`` uses the default zero padding.

    The fused modules are built from (weight, bias, stride, padding, dilation,
    groups) only, so any other ``padding_mode`` would be silently dropped.
    """
    return getattr(conv, "padding_mode", "zeros") == "zeros"


def _fuse_conv_relu_pairs(graph, modules):
    """Replace each fusable ``Conv2d -> ReLU`` module pair in ``graph``.

    The Conv2d module is swapped for a ``Conv2dReLU`` and the ReLU node is
    erased after redirecting its users to the conv node.
    """
    call_counts = _count_call_module_targets(graph)
    for node in graph.nodes:
        # ``node`` is the ReLU call; ``node.args[0]`` is the Conv2d call feeding it.
        if not matches_module_pattern((torch.nn.Conv2d, torch.nn.ReLU), node, modules):
            continue
        conv_node = node.args[0]
        if len(conv_node.users) > 1:
            # The raw conv output is used elsewhere too; fusing the ReLU into
            # the conv would change what those other users receive.
            continue
        if call_counts.get(conv_node.target, 0) > 1:
            # The same Conv2d module is called at another site, which would
            # silently gain a ReLU once the module is replaced.
            continue
        conv = modules[conv_node.target]
        if not _padding_mode_is_fusable(conv):
            continue
        fused = Conv2dReLU(
            conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups
        )
        replace_node_module(conv_node, modules, fused)
        node.replace_all_uses_with(conv_node)
        graph.erase_node(node)


def _is_fusable_conv_add_relu(conv_node, add_node, relu_node, modules, call_counts):
    """Return True if the nodes form a safely fusable ``relu(conv(x) + residual)``."""
    import operator

    if conv_node.op != "call_module" or relu_node.op != "call_module":
        return False
    if add_node.op != "call_function" or add_node.target not in (operator.add, torch.add):
        return False
    # Exact types: a subclass (or an already-fused module, should it derive
    # from these) may override forward, which the fused module won't reproduce.
    conv = modules.get(conv_node.target)
    if type(conv) is not torch.nn.Conv2d or not _padding_mode_is_fusable(conv):
        return False
    if type(modules.get(relu_node.target)) is not torch.nn.ReLU:
        return False
    # Adjacency in graph order says nothing about data flow: verify the edges.
    if len(conv_node.args) != 1 or conv_node.kwargs:
        return False
    if len(add_node.args) != 2 or add_node.kwargs or add_node.args[0] is not conv_node:
        return False
    residual = add_node.args[1]
    # NOTE(review): assumes Conv2dAddReLU expects a tensor residual; scalar
    # operands (and conv + conv) are left unfused.
    if not isinstance(residual, fx.Node) or residual is conv_node:
        return False
    if len(relu_node.args) != 1 or relu_node.kwargs or relu_node.args[0] is not add_node:
        return False
    # Intermediates must have no other consumers (they are erased), and the
    # ReLU module must not be called elsewhere (it is replaced in place).
    if len(conv_node.users) != 1 or len(add_node.users) != 1:
        return False
    return call_counts.get(relu_node.target, 0) == 1


def _fuse_conv_add_relu(graph, modules, max_fusions):
    """Replace fusable ``Conv2d -> add -> ReLU`` chains in ``graph``.

    Only chains whose three nodes are consecutive among the graph's
    ``call_function``/``call_module`` nodes are considered. The ReLU node is
    turned into the fused ``Conv2dAddReLU`` call; the conv and add nodes are
    erased. Stops after ``max_fusions`` fusions unless it is None.

    Returns:
        The number of chains fused.
    """
    call_counts = _count_call_module_targets(graph)
    window = []
    fused_count = 0
    for node in graph.nodes:
        if max_fusions is not None and fused_count >= max_fusions:
            break
        if node.op not in ("call_function", "call_module"):
            # Only evaluate the pattern when the window actually changes.
            continue
        window = (window + [node])[-3:]
        if len(window) < 3:
            continue
        conv_node, add_node, relu_node = window
        if not _is_fusable_conv_add_relu(conv_node, add_node, relu_node, modules, call_counts):
            continue
        conv = modules[conv_node.target]
        fused = Conv2dAddReLU(
            conv.weight, conv.bias, conv.stride, conv.padding, conv.dilation, conv.groups
        )
        # The ReLU node becomes the fused call, consuming the conv input and
        # the residual (the add's second operand) directly.
        replace_node_module(relu_node, modules, fused)
        relu_node.args = (conv_node.args[0], add_node.args[1])
        graph.erase_node(add_node)
        graph.erase_node(conv_node)
        fused_count += 1
        # Erased/rewritten nodes must never take part in another match.
        window = []
    return fused_count