utils/check_static_imports.py (74 lines of code) (raw):

# coding=utf-8
# Copyright 2022-present, the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains a tool to reformat static imports in `huggingface_hub.__init__.py`."""

import argparse
import os
import re
import sys
import tempfile
from pathlib import Path
from typing import NoReturn

from ruff.__main__ import find_ruff_bin

from huggingface_hub import _SUBMOD_ATTRS


# Location of the package `__init__.py`, resolved relative to this script.
INIT_FILE_PATH = Path(__file__).parents[1] / "src" / "huggingface_hub" / "__init__.py"

# Marker separating the hand-written part of `__init__.py` from the generated one.
IF_TYPE_CHECKING_LINE = "\nif TYPE_CHECKING: # pragma: no cover\n"

# Matches the whole `_SUBMOD_ATTRS = {...}` dictionary literal (up to the first `}`).
SUBMOD_ATTRS_PATTERN = re.compile(r"_SUBMOD_ATTRS = {[^}]+}")


def check_static_imports(update: bool) -> NoReturn:
    """Check all imports are made twice (1 in lazy-loading and 1 in static checks).

    For more explanations, see `./src/huggingface_hub/__init__.py`.
    This script is used in the `make style` and `make quality` checks.

    The expected `__init__.py` content is rebuilt from the imported `_SUBMOD_ATTRS`
    mapping, passed through `ruff check --fix` and `ruff format`, and compared with
    the file on disk.

    Args:
        update: If `True` and the file on disk differs from the expected content,
            overwrite the file with the expected content (exit status 0). If
            `False`, a difference is only reported (exit status 1).

    Raises:
        SystemExit: Always. Status 0 when the file is up to date or has been
            updated; status 1 when it is out of date, when the `_SUBMOD_ATTRS`
            definition cannot be found, or when ruff terminates abnormally.
    """
    init_content = INIT_FILE_PATH.read_text(encoding="utf-8")

    # Get first half of the `__init__.py` file.
    # WARNING: Content after this part will be entirely re-generated which means
    # human-edited changes will be lost !
    init_content_before_static_checks = init_content.split(IF_TYPE_CHECKING_LINE)[0]

    # Search and replace `_SUBMOD_ATTRS` dictionary definition. This ensures modules
    # and functions that can be lazy-loaded are alphabetically ordered for readability.
    if SUBMOD_ATTRS_PATTERN.search(init_content_before_static_checks) is None:
        print("Error: _SUBMOD_ATTRS dictionary definition not found in `./src/huggingface_hub/__init__.py`.")
        sys.exit(1)

    # Modules are sorted, and each attribute list is de-duplicated and sorted.
    _submod_attrs_definition = (
        "_SUBMOD_ATTRS = {\n"
        + "\n".join(
            f' "{module}": [\n'
            + "\n".join(f' "{attr}",' for attr in sorted(set(_SUBMOD_ATTRS[module])))
            + "\n ],"
            for module in sorted(_SUBMOD_ATTRS)
        )
        + "\n}"
    )
    # A callable replacement inserts the text verbatim; a plain string replacement
    # would be scanned by `re` for backslash escapes and group references.
    reordered_content_before_static_checks = SUBMOD_ATTRS_PATTERN.sub(
        lambda _match: _submod_attrs_definition, init_content_before_static_checks
    )

    # Generate the static imports given the `_SUBMOD_ATTRS` dictionary.
    static_imports = [
        f" from .{module} import {attr} # noqa: F401"
        for module, attributes in _SUBMOD_ATTRS.items()
        for attr in attributes
    ]

    # Generate the expected `__init__.py` file content and apply formatter on it.
    with tempfile.TemporaryDirectory() as tmpdir:
        filepath = Path(tmpdir) / "__init__.py"
        filepath.write_text(
            reordered_content_before_static_checks + IF_TYPE_CHECKING_LINE + "\n".join(static_imports) + "\n",
            encoding="utf-8",
        )

        ruff_bin = find_ruff_bin()

        # `ruff check` exits with 0 (clean) or 1 (violations remain after `--fix`);
        # any other status means it did not run properly, so the result is unusable.
        check_status = os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "check", str(filepath), "--fix", "--quiet"])
        if check_status not in (0, 1):
            print(f"Error: `ruff check` terminated abnormally (exit status {check_status}).")
            sys.exit(1)

        # `ruff format` exits with 0 on success. Stop here otherwise, so that
        # unformatted content is never compared against or written to `__init__.py`.
        format_status = os.spawnv(os.P_WAIT, ruff_bin, ["ruff", "format", str(filepath), "--quiet"])
        if format_status != 0:
            print(f"Error: `ruff format` terminated abnormally (exit status {format_status}).")
            sys.exit(1)

        expected_init_content = filepath.read_text(encoding="utf-8")

    # If expected `__init__.py` content is different, test fails. If '--update'
    # is used, `__init__.py` file is updated instead and the script exits with 0.
    if init_content != expected_init_content:
        if update:
            INIT_FILE_PATH.write_text(expected_init_content, encoding="utf-8")

            print(
                "✅ Imports have been updated in `./src/huggingface_hub/__init__.py`."
                "\n Please make sure the changes are accurate and commit them."
            )
            sys.exit(0)
        else:
            print(
                "❌ Expected content mismatch in"
                " `./src/huggingface_hub/__init__.py`.\n It is most likely that you"
                " added a module/function to `_SUBMOD_ATTRS` and did not update the"
                " 'static import'-part.\n Please run `make style` or `python"
                " utils/check_static_imports.py --update`."
            )
            sys.exit(1)

    print("✅ All good! (static imports)")
    sys.exit(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--update",
        action="store_true",
        help="Whether to fix `./src/huggingface_hub/__init__.py` if a change is detected.",
    )
    args = parser.parse_args()

    check_static_imports(update=args.update)