# scripts/convert-types.py (235 lines of code) (raw):

#!/usr/bin/env python # Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Script to convert type hints to follow PEP-0585 For more information, see https://peps.python.org/pep-0585 To run from the repository's root directory: python convert-types.py """ from __future__ import annotations from collections.abc import Callable, Iterator import difflib from glob import glob import logging import re import sys from typing import NamedTuple, TypeVar # TODO: # - False positives with lambdas and dict comprehensions due to the `:` misinterpreted as a type hint # - Sort imports case insensitive (e.g. 
PIL, Flask) # - Type hint arguments can be lists like `Callable[[a, b], c]` # Cases not covered: # - Multi-line imports like `from M import (\nA,\nB,\n)` # - Importing `typing` directly like `import typing` and `x: typing.Any` # - Parsing types with `|` syntax like Union or Optional # - typing.re.Match --> re.Match # - typing.re.Pattern --> re.Pattern BUILTIN_TYPES = {"Tuple", "List", "Dict", "Set", "FrozenSet", "Type"} COLLECTIONS_TYPES = {"Deque", "DefaultDict", "OrderedDict", "Counter", "ChainMap"} COLLECTIONS_ABC_TYPES = { "Awaitable", "Coroutine", "AsyncIterable", "AsyncIterator", "AsyncGenerator", "Iterable", "Iterator", "Generator", "Reversible", "Container", "Collection", "Callable", "AbstractSet", "MutableSet", "Mapping", "MutableMapping", "Sequence", "MutableSequence", "ByteString", "MappingView", "KeysView", "ItemsView", "ValuesView", } CONTEXTLIB_TYPES = {"ContextManager", "AsyncContextManager"} RE_TYPES = {"Match", "Pattern"} RENAME_TYPES = { "Tuple": "tuple", # builtin "List": "list", # builtin "Dict": "dict", # builtin "Set": "set", # builtin "FrozenSet": "frozenset", # builtin "Type": "type", # builtin "Deque": "deque", # collections "DefaultDict": "defaultdict", # collections "AbstractSet": "Set", # collections.abc "ContextManager": "AbstractContextManager", # contextlib "AsyncContextManager": "AbstractAsyncContextManager", # contextlib } # Parser a = String -> (a, String) a = TypeVar("a") Parser = Callable[[str], tuple[a, str]] class TypeHint(NamedTuple): name: str args: list[TypeHint] def __repr__(self) -> str: match (self.name, self.args): case ("Optional", [x]): return f"{x} | None" case ("Union", args): return " | ".join(map(str, args)) case (name, []): return name case (name, args): return f"{name}[{', '.join(map(str, args))}]" def patch(self, types: set[str]) -> TypeHint: if self.name in types: name = RENAME_TYPES.get(self.name, self.name) else: name = self.name return TypeHint(name, [arg.patch(types) for arg in self.args]) def 
patch_file(file_path: str, dry_run: bool = False, quiet: bool = False) -> None: with open(file_path) as f: before = f.read() try: lines = [line.rstrip() for line in before.splitlines()] if types := find_typing_imports(lines): lines = insert_import_annotations(lines) lines = [patched for line in lines for patched in patch_imports(line)] lines = sort_imports(lines) after = patch_type_hints("\n".join(lines), types) + "\n" if before == after: return if not dry_run: with open(file_path, "w") as f: f.write(after) print(file_path) elif not quiet: print(f"| {file_path}") print(f"+--{'-' * len(file_path)}") diffs = difflib.context_diff( before.splitlines(keepends=True), after.splitlines(keepends=True), fromfile="Before changes", tofile="After changes", n=1, ) sys.stdout.writelines(diffs) print(f"+{'=' * 100}") print("| Press [ENTER] to continue to the next file") input() except Exception: logging.exception(f"Could not process file: {file_path}") def insert_import_annotations(lines: list[str]) -> list[str]: new_import = "from __future__ import annotations" if new_import in lines: return lines match find_import(lines): case None: return lines case i: if lines[i].startswith("from __future__ import "): return lines[:i] + [new_import] + lines[i:] return lines[:i] + [new_import, ""] + lines[i:] def find_typing_imports(lines: list[str]) -> set[str]: return { name.strip() for line in lines if line.startswith("from typing import ") for name in line.split("import")[1].split(",") } def find_import(lines: list[str]) -> int | None: for i, line in enumerate(lines): if line.startswith(("import ", "from ")): return i return None def get_imports_group(lines: list[str]) -> tuple[list[str], list[str]]: for i, line in enumerate(lines): if not line.strip() or line.startswith("#"): return (lines[:i], lines[i:]) return ([], lines) def import_name(line: str) -> str: match line.split(): case ["import", name, *_]: return name case ["from", name, "import", *_]: return name raise ValueError(f"not an 
import: {line}") def sort_imports(lines: list[str]) -> list[str]: match find_import(lines): case None: return lines case i: (imports, left) = get_imports_group(lines[i:]) if imports: return lines[:i] + sorted(imports, key=import_name) + sort_imports(left) return left def patch_imports(line: str) -> Iterator[str]: if not line.startswith("from typing import "): yield line return types = find_typing_imports([line]) collections_types = types.intersection(COLLECTIONS_TYPES) collections_abc_types = types.intersection(COLLECTIONS_ABC_TYPES) contextlib_types = types.intersection(CONTEXTLIB_TYPES) re_types = types.intersection(RE_TYPES) typing_types = ( types - BUILTIN_TYPES - COLLECTIONS_TYPES - COLLECTIONS_ABC_TYPES - CONTEXTLIB_TYPES - RE_TYPES - {"Optional", "Union"} ) rename = lambda name: RENAME_TYPES.get(name, name) if collections_types: names = sorted(map(rename, collections_types)) yield f"from collections import {', '.join(names)}" if collections_abc_types: names = sorted(map(rename, collections_abc_types)) yield f"from collections.abc import {', '.join(names)}" if contextlib_types: names = sorted(map(rename, contextlib_types)) yield f"from contextlib import {', '.join(names)}" if re_types: names = sorted(map(rename, re_types)) yield f"from re import {', '.join(names)}" if typing_types: names = sorted(map(rename, typing_types)) yield f"from typing import {', '.join(names)}" def patch_type_hints(txt: str, types: set[str]) -> str: if m := re.search(rf"(?:->|:) *(\w+)", txt): (typ, left) = parse_type_hint(txt[m.start(1) :]) return f"{txt[:m.start(1)]}{typ.patch(types)}{patch_type_hints(left, types)}" return txt # Parser combinators def parse_text(src: str, txt: str) -> tuple[str, str]: if src.startswith(txt): return (src[: len(txt)], src[len(txt) :]) raise SyntaxError("text") def parse_identifier(src: str) -> tuple[str, str]: if m := re.search(r"[\w\._]+", src): return (m.group(), src[m.end() :]) raise SyntaxError("name") def parse_zero_or_more(src: str, parser: 
Parser[a]) -> tuple[list[a], str]: try: (x, src) = parser(src) (xs, src) = parse_zero_or_more(src, parser) return ([x] + xs, src) except SyntaxError: return ([], src) def parse_comma_separated(src: str, parser: Parser[a]) -> tuple[list[a], str]: def parse_next(src: str) -> tuple[a, str]: (_, src) = parse_text(src, ",") (_, src) = parse_zero_or_more(src, lambda src: parse_text(src, " ")) return parser(src) try: (x, src) = parser(src) (xs, src) = parse_zero_or_more(src, parse_next) return ([x] + xs, src) except SyntaxError: return ([], src) def parse_type_hint(src: str) -> tuple[TypeHint, str]: (name, src) = parse_identifier(src) try: (_, src) = parse_text(src, "[") (args, src) = parse_comma_separated(src, parse_type_hint) (_, src) = parse_text(src, "]") return (TypeHint(name, args), src) except SyntaxError: return (TypeHint(name, []), src) def run(patterns: list[str], dry_run: bool = False, quiet: bool = False): for pattern in patterns: for filename in glob(pattern, recursive=True): patch_file(filename, dry_run, quiet) if __name__ == "__main__": import argparse assert sys.version_info.major == 3, "Requires Python 3" assert sys.version_info.minor >= 10, "Requires Python >= 3.10 for pattern matching" parser = argparse.ArgumentParser() parser.add_argument("patterns", nargs="*", default=["**/*.py"]) parser.add_argument("--dry-run", action="store_true") parser.add_argument("--quiet", action="store_true") args = parser.parse_args() run(**args.__dict__)