def call()

in optimum/executorch/passes/remove_padding_idx_embedding_pass.py [0:0]


    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Strip ``padding_idx`` from every edge-dialect ``aten.embedding`` node.

        ``aten::embedding`` has the schema
        ``embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False)``.
        ``padding_idx`` only influences gradients, so dropping it leaves the
        forward result unchanged. It is removed whether it was passed
        positionally or as a keyword. Any arguments that followed it
        positionally are re-passed as keywords, so they keep their meaning
        instead of shifting into ``padding_idx``'s slot.

        Args:
            graph_module: The graph module whose nodes are rewritten in place.

        Returns:
            A ``PassResult`` wrapping ``graph_module``. ``modified`` is True
            only if at least one embedding node was rewritten.
        """
        # Schema arguments that come after padding_idx, in positional order.
        trailing_arg_names = ("scale_grad_by_freq", "sparse")
        modified = False
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target != exir_ops.edge.aten.embedding.default:
                continue

            args = tuple(node.args)
            kwargs = dict(node.kwargs)
            if len(args) <= 2 and "padding_idx" not in kwargs:
                # padding_idx was not given explicitly; nothing to strip.
                continue

            kwargs.pop("padding_idx", None)
            # args[2] is padding_idx; anything after it is kept, but by name,
            # so it does not slide into padding_idx's positional slot.
            kwargs.update(zip(trailing_arg_names, args[3:]))
            node.args = args[:2]
            node.kwargs = kwargs
            modified = True

        graph_module.recompile()
        return PassResult(graph_module, modified)