optimum/executorch/passes/remove_padding_idx_embedding_pass.py (13 lines of code) (raw):
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
class RemovePaddingIdxEmbeddingPass(ExportPass):
    """
    An ExportPass that drops the positional `padding_idx` argument
    from edge-dialect aten.embedding.default operator calls.

    Only calls carrying exactly three positional arguments are rewritten;
    calls with fewer or more positional arguments are left untouched.
    """

    def __init__(self) -> None:
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Strip `padding_idx` from matching embedding nodes, in place.

        Args:
            graph_module: Graph module whose nodes are scanned and, where
                they match, rewritten in place.

        Returns:
            A PassResult wrapping the same `graph_module`. Its `modified`
            flag is True only when at least one node was rewritten.
        """
        # Resolve the target op once instead of on every node comparison.
        embedding_op = exir_ops.edge.aten.embedding.default
        modified = False

        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target != embedding_op:
                continue
            # node.args[2] is the padding_idx; keep only the first two args.
            if len(node.args) == 3:
                node.args = node.args[:2]
                modified = True

        # Regenerate the module's code only if the graph actually changed,
        # and report the real outcome rather than always claiming a change.
        if modified:
            graph_module.recompile()
        return PassResult(graph_module, modified)