in optimum/fx/optimization/transformations.py [0:0]
def transform(self, graph_module: "GraphModule") -> "GraphModule":
    for node in graph_module.graph.nodes:
        if node.op == "call_module" and node.args[0].op == "call_module":
            if (
                type(graph_module.get_submodule(node.target)) is torch.nn.BatchNorm1d
                and type(graph_module.get_submodule(node.args[0].target)) is torch.nn.Linear
            ):
                # handle the case torch.nn.Linear --> torch.nn.BatchNorm1d
                if len(node.args[0].users) > 1:  # the output of the linear is also used by other nodes
                    continue
                candidate_linear = graph_module.get_submodule(node.args[0].target)
                candidate_batchnorm1d = graph_module.get_submodule(node.target)
                # fuse only if the number of linear output features equals the batchnorm num_features,
                # which is the case with 2D tensors; the case where the linear input is (N, C, L_in),
                # the output is (N, C, L_out) and C == L_out is NOT handled, as it cannot be fused
                if candidate_linear.weight.shape[0] == candidate_batchnorm1d.weight.shape[0]:
                    fused_linear = self.fuse(
                        linear=candidate_linear, bn1d=candidate_batchnorm1d, bn1d_before=False
                    )
                    # replace the old nn.Linear with the fused one
                    parent_name, _, name = node.args[0].target.rpartition(".")
                    parent_module = graph_module.get_submodule(parent_name)
                    setattr(parent_module, name, fused_linear)
                    # delete the batchnorm from the modules
                    parent_name, _, name = node.target.rpartition(".")
                    parent_module = graph_module.get_submodule(parent_name)
                    delattr(parent_module, name)
                    node.replace_all_uses_with(node.args[0])
                    graph_module.graph.erase_node(node)  # delete the BatchNorm1d node
            elif (
                type(graph_module.get_submodule(node.target)) is torch.nn.Linear
                and type(graph_module.get_submodule(node.args[0].target)) is torch.nn.BatchNorm1d
            ):
                # handle the case torch.nn.BatchNorm1d --> torch.nn.Linear
                if len(node.args[0].users) > 1:  # the output of the batchnorm is also used by other nodes
                    continue
                candidate_linear = graph_module.get_submodule(node.target)
                candidate_batchnorm1d = graph_module.get_submodule(node.args[0].target)
                # fuse only if the number of linear input features equals the batchnorm num_features,
                # which is the case with 2D tensors; the case where the linear input is (N, C, L_in)
                # and C == L_in is NOT handled, as it cannot be fused
                if candidate_batchnorm1d.weight.shape[0] == candidate_linear.weight.shape[1]:
                    fused_linear = self.fuse(linear=candidate_linear, bn1d=candidate_batchnorm1d, bn1d_before=True)
                    # replace the old nn.Linear with the fused one
                    parent_name, _, name = node.target.rpartition(".")
                    parent_module = graph_module.get_submodule(parent_name)
                    setattr(parent_module, name, fused_linear)
                    # delete the batchnorm from the modules
                    parent_name, _, name = node.args[0].target.rpartition(".")
                    parent_module = graph_module.get_submodule(parent_name)
                    delattr(parent_module, name)
                    batchnorm_node = node.args[0]
                    node.args[0].replace_all_uses_with(node.args[0].args[0])
                    graph_module.graph.erase_node(batchnorm_node)  # delete the BatchNorm1d node
    return graph_module
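
The `self.fuse` helper referenced above is not part of this excerpt. As a point of reference, a minimal sketch of the algebra such a helper is expected to perform (folding the BatchNorm1d affine parameters and running statistics into the Linear weight and bias, for both orderings selected by `bn1d_before`) follows; `fuse_linear_bn1d` is a hypothetical standalone name, and the sketch assumes an affine BatchNorm1d evaluated with its running statistics:

import torch


def fuse_linear_bn1d(
    linear: torch.nn.Linear, bn1d: torch.nn.BatchNorm1d, bn1d_before: bool
) -> torch.nn.Linear:
    # In eval mode, BN(x) = scale * x + shift, with scale and shift computed from the
    # affine parameters and running statistics. Fold this affine map into the Linear layer.
    with torch.no_grad():
        scale = bn1d.weight / torch.sqrt(bn1d.running_var + bn1d.eps)
        shift = bn1d.bias - scale * bn1d.running_mean
        bias = linear.bias if linear.bias is not None else torch.zeros(linear.out_features, dtype=linear.weight.dtype)
        if bn1d_before:
            # Linear(BN(x)) = W (scale * x + shift) + b
            fused_weight = linear.weight * scale.unsqueeze(0)  # scale each input feature (column)
            fused_bias = linear.weight @ shift + bias
        else:
            # BN(Linear(x)) = scale * (W x + b) + shift
            fused_weight = linear.weight * scale.unsqueeze(1)  # scale each output feature (row)
            fused_bias = scale * bias + shift
    fused = torch.nn.Linear(linear.in_features, linear.out_features)
    fused.weight = torch.nn.Parameter(fused_weight)
    fused.bias = torch.nn.Parameter(fused_bias)
    return fused


# quick equivalence check, with the batchnorm in eval mode so its running statistics are used
linear, bn = torch.nn.Linear(8, 4), torch.nn.BatchNorm1d(4).eval()
x = torch.randn(2, 8)
assert torch.allclose(bn(linear(x)), fuse_linear_bn1d(linear, bn, bn1d_before=False)(x), atol=1e-5)

This is only a sketch under those assumptions, not the library's implementation; it illustrates why the shape checks in `transform` compare the batchnorm `num_features` against the linear output features (`bn1d_before=False`) or input features (`bn1d_before=True`).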