optimum/exporters/executorch/recipe_registry.py (17 lines of code) (raw):
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import pkgutil
logger = logging.getLogger(__name__)
recipe_registry = {}
package_name = "optimum.exporters.executorch.recipes"
def register_recipe(recipe_name):
"""
Decorator to register a recipe for exporting and lowering an ExecuTorch model under a specific name.
Args:
recipe_name (`str`):
The name of the recipe to associate with a callable recipe.
Returns:
`Callable`:
The original function wrapped as a registered recipe.
Example:
```python
@register_recipe("my_new_recipe")
def my_new_recipe(...):
...
```
"""
def decorator(func):
recipe_registry[recipe_name] = func
return func
return decorator
def discover_recipes():
"""
Dynamically discovers and imports all recipe modules within the `optimum.exporters.executorch.recipes` package.
Ensures recipes under `./recipes` directory are dynamically loaded without requiring manual imports.
Notes:
New recipes **must** be added to the `./recipes` directory to be discovered and used by `main_export`.
Failure to do so will prevent dynamic discovery and registration. Recipes must also use the
`@register_recipe` decorator to be properly registered in the `recipe_registry`.
"""
package = importlib.import_module(package_name)
package_path = package.__path__
for _, module_name, _ in pkgutil.iter_modules(package_path):
logger.info(f"Importing {package_name}.{module_name}")
importlib.import_module(f"{package_name}.{module_name}")