def transform()

in optimum/fx/optimization/transformations.py [0:0]


    def transform(self, graph_module: "GraphModule") -> "GraphModule":
        for node in graph_module.graph.nodes:
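            # look for a "call_module" node whose first input is itself produced by a "call_module" node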
            if node.op == "call_module" and node.args[0].op == "call_module":
                if (
                    type(graph_module.get_submodule(node.target)) is torch.nn.BatchNorm1d
                    and type(graph_module.get_submodule(node.args[0].target)) is torch.nn.Linear
                ):
                    # handle the case torch.nn.Linear --> torch.nn.BatchNorm1d

                    if len(node.args[0].users) > 1:  # Output of linear is used by other nodes
                        continue

                    candidate_linear = graph_module.get_submodule(node.args[0].target)
                    candidate_batchnorm1d = graph_module.get_submodule(node.target)

                    # Fuse only if the number of Linear output features equals the BatchNorm num_features,
                    # which is the case with 2D tensors.
                    # The case where the Linear input is (N, C, L_in), the output is (N, C, L_out) and
                    # C == L_out is NOT handled, as it cannot be fused.
                    if candidate_linear.weight.shape[0] == candidate_batchnorm1d.weight.shape[0]:
                        fused_linear = self.fuse(
                            linear=candidate_linear, bn1d=candidate_batchnorm1d, bn1d_before=False
                        )

                        # replace the old nn.Linear with the fused one
                        parent_name, _, name = node.args[0].target.rpartition(".")
                        parent_module = graph_module.get_submodule(parent_name)
                        setattr(parent_module, name, fused_linear)

                        # delete batchnorm from the modules
                        parent_name, _, name = node.target.rpartition(".")
                        parent_module = graph_module.get_submodule(parent_name)
                        delattr(parent_module, name)

                        node.replace_all_uses_with(node.args[0])

                        graph_module.graph.erase_node(node)  # delete BatchNorm1d
                elif (
                    type(graph_module.get_submodule(node.target)) is torch.nn.Linear
                    and type(graph_module.get_submodule(node.args[0].target)) is torch.nn.BatchNorm1d
                ):
                    # handle the case torch.nn.BatchNorm1d --> torch.nn.Linear
                    if len(node.args[0].users) > 1:  # Output of batchnorm is used by other nodes
                        continue

                    candidate_linear = graph_module.get_submodule(node.target)
                    candidate_batchnorm1d = graph_module.get_submodule(node.args[0].target)

                    # Fuse only if the number of Linear input features equals the BatchNorm num_features,
                    # which is the case with 2D tensors.
                    # The case where the Linear input is (N, C, L_in) and C == L_in is NOT handled,
                    # as it cannot be fused.
                    if candidate_batchnorm1d.weight.shape[0] == candidate_linear.weight.shape[1]:
                        fused_linear = self.fuse(linear=candidate_linear, bn1d=candidate_batchnorm1d, bn1d_before=True)

                        # replace the old nn.Linear with the fused one
                        parent_name, _, name = node.target.rpartition(".")
                        parent_module = graph_module.get_submodule(parent_name)
                        setattr(parent_module, name, fused_linear)

                        # delete batchnorm from the modules
                        parent_name, _, name = node.args[0].target.rpartition(".")
                        parent_module = graph_module.get_submodule(parent_name)
                        delattr(parent_module, name)

                        batchnorm_node = node.args[0]
                        node.args[0].replace_all_uses_with(node.args[0].args[0])

                        graph_module.graph.erase_node(batchnorm_node)  # delete BatchNorm1d
        return graph_module
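
The folding performed by self.fuse is not shown in this excerpt. The sketch below illustrates the standard algebra for folding an affine BatchNorm1d (in eval mode, with tracked running statistics) into an adjacent Linear for both orderings; the helper name fuse_linear_bn1d and the sanity check at the bottom are illustrative assumptions, not code from the file.

import torch


def fuse_linear_bn1d(linear: torch.nn.Linear, bn1d: torch.nn.BatchNorm1d, bn1d_before: bool) -> torch.nn.Linear:
    # scale = gamma / sqrt(running_var + eps), one factor per normalized feature
    scale = bn1d.weight / torch.sqrt(bn1d.running_var + bn1d.eps)
    linear_bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features)
    fused = torch.nn.Linear(linear.in_features, linear.out_features, bias=True)
    with torch.no_grad():
        if bn1d_before:
            # BatchNorm1d -> Linear: y = W @ (scale * (x - mean) + beta) + b
            fused.weight.copy_(linear.weight * scale)  # scales the input columns of W
            fused.bias.copy_(linear.weight @ (bn1d.bias - scale * bn1d.running_mean) + linear_bias)
        else:
            # Linear -> BatchNorm1d: y = scale * (W @ x + b - mean) + beta
            fused.weight.copy_(linear.weight * scale[:, None])  # scales the output rows of W
            fused.bias.copy_(scale * (linear_bias - bn1d.running_mean) + bn1d.bias)
    return fused


# quick numerical check of the folding algebra (eval mode, random running statistics)
linear, bn = torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4)
bn.eval()
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 1.5)
x = torch.randn(2, 8)
fused = fuse_linear_bn1d(linear, bn, bn1d_before=False)
print(torch.allclose(fused(x), bn(linear(x)), atol=1e-5))  # expected: True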