scripts/generate_imports.py (100 lines of code) (raw):

# Copyright (c) Meta Platforms, Inc

"""
Generates imports files for bijectors, distributions, and parameters.

This script assumes that you have used `setup.py develop`.
"""

import errno
import io
import os

import black
import flowtorch
import torch
from flowtorch.utils import (
    classname,
    copyright_header,
    list_bijectors,
    list_distributions,
    list_parameters,
)

# Turn the raw copyright text into "# "-prefixed comment lines.
# NOTE: this deliberately rebinds the name imported from flowtorch.utils.
copyright_header = "".join(["# " + ln + "\n" for ln in copyright_header.splitlines()])

# Module docstring emitted at the top of every generated file.
autogen_msg = """\"\"\"
Warning: This file was generated by flowtorch/scripts/generate_imports.py
Do not modify or delete!
\"\"\"
"""

# Non-FlowTorch imports needed by the generated bijectors module.
bijectors_imports = """import inspect
from typing import cast, List, Tuple

import torch"""

# Code appended verbatim (then black-formatted) to the generated bijectors
# module. It relies on `standard_bijectors`, `meta_bijectors` and `Bijector`
# being defined earlier in that generated file.
bijectors_code = """def isbijector(cls: type) -> bool:
    # A class must inherit from flowtorch.Bijector to be considered a valid bijector
    return issubclass(cls, Bijector)


def standard_bijector(cls: type) -> bool:
    # "Standard bijectors" are the ones we can perform standard automated tests upon
    return (
        inspect.isclass(cls)
        and isbijector(cls)
        and cls.__name__ not in [clx for clx, _ in meta_bijectors]
    )


# Determine invertible bijectors
invertible_bijectors = []
for bij_name, cls in standard_bijectors:
    # TODO: Use factored out version of the following
    # Define plan for flow
    event_dim = max(cls.domain.event_dim, 1)  # type: ignore
    event_shape = event_dim * [4]
    # base_dist = dist.Normal(torch.zeros(event_shape), torch.ones(event_shape))
    bij = cls(shape=torch.Size(event_shape))

    try:
        y = torch.randn(*bij.forward_shape(event_shape))
        bij.inverse(y)
    except NotImplementedError:
        pass
    else:
        invertible_bijectors.append((bij_name, cls))


__all__ = ["standard_bijectors", "meta_bijectors", "invertible_bijectors"] + [
    cls
    for cls, _ in cast(List[Tuple[str, Bijector]], meta_bijectors)
    + cast(List[Tuple[str, Bijector]], standard_bijectors)
]"""

# black configuration used for every generated file.
mode = black.FileMode()
fast = False


def _print_class_imports(classes, file):
    """Write one ``from <module> import <name>`` line per pair, then a blank line.

    Args:
        classes: iterable of ``(name, cls)`` pairs; ``cls.__module__`` is the
            module imported from and ``name`` is the imported symbol.
        file: text stream to write to.
    """
    for s, cls in classes:
        print(f"from {cls.__module__} import {s}", file=file)
    print("", file=file)


def _format_and_write(filename, contents):
    """Format ``contents`` with black and write the result to ``filename``.

    Args:
        filename: path of the file to create or overwrite.
        contents: Python source text to format.
    """
    try:
        contents = black.format_file_contents(contents, fast=fast, mode=mode)
    except black.NothingChanged:
        # black reports already-formatted input by raising rather than
        # returning it unchanged; the input is exactly what we want to write.
        pass
    with open(filename, "w", encoding="utf-8") as real_file:
        real_file.write(contents)


def generate_imports_plain(filename, classes):
    """Generate a module that re-exports ``classes`` and lists them in ``__all__``.

    Args:
        filename: path of the ``__init__.py`` to create or overwrite.
        classes: iterable of ``(name, cls)`` pairs to import and export.
    """
    # Sort classes by qualified name so the output is deterministic.
    classes = sorted(classes, key=lambda tup: classname(tup[1]))

    with io.StringIO() as file:
        print(copyright_header, file=file)
        print(autogen_msg, file=file)
        _print_class_imports(classes, file)

        all_list = ",\n    ".join(f'"{s}"' for s, _ in classes)
        print(f"__all__ = [\n    {all_list}\n]", end="", file=file)
        contents = file.getvalue()

    _format_and_write(filename, contents)


def _split_bijectors(bijectors):
    """Partition ``(name, cls)`` bijector pairs into ``(meta, standard)`` lists.

    Standard bijectors can be initialized with a shape. A class is treated
    as a meta bijector when that construction raises TypeError, or when it
    returns a ``flowtorch.Lazy`` (presumably because further keyword
    arguments are still required, e.g. bij.Compose).
    """
    # TODO: Refactor this into flowtorch.utils.ismetabijector
    meta_bijectors = []
    standard_bijectors = []
    for b in bijectors:
        try:
            cls = b[1]
            x = cls(shape=torch.Size([2] * cls.domain.event_dim))
        except TypeError:
            meta_bijectors.append(b)
        else:
            if isinstance(x, flowtorch.Lazy):
                meta_bijectors.append(b)
            else:
                standard_bijectors.append(b)
    return meta_bijectors, standard_bijectors


def generate_imports_bijectors(filename):
    """Generate the bijectors package ``__init__.py``.

    The generated module imports every bijector, defines the
    ``standard_bijectors`` and ``meta_bijectors`` lists of
    ``("Name", Name)`` pairs, and then appends ``bijectors_code``.

    Args:
        filename: path of the ``__init__.py`` to create or overwrite.
    """
    meta_bijectors, standard_bijectors = _split_bijectors(list_bijectors())

    # Sort classes by qualified name so the output is deterministic.
    classes = sorted(
        standard_bijectors + meta_bijectors, key=lambda tup: classname(tup[1])
    )

    # Bodies of the list literals, one ("Name", Name) tuple per line.
    meta_str = ",\n    ".join(f'("{b[0]}", {b[0]})' for b in meta_bijectors)
    standard_str = ",\n    ".join(f'("{b[0]}", {b[0]})' for b in standard_bijectors)

    with io.StringIO() as file:
        # Copyright header and warning message
        print(copyright_header, file=file)
        print(autogen_msg, file=file)

        # Non-FlowTorch imports, then FlowTorch imports
        print(bijectors_imports, file=file)
        _print_class_imports(classes, file)

        # Lists of bijectors for each type
        print(f"standard_bijectors = [\n    {standard_str}\n]\n", file=file)
        print(f"meta_bijectors = [\n    {meta_str}\n]\n", file=file)

        # Rest of code
        print(bijectors_code, file=file)
        contents = file.getvalue()

    _format_and_write(filename, contents)


def _ensure_module_dirs(root, names):
    """Create ``root/<name>`` for each name, tolerating folders that already exist."""
    for m in names:
        # Handle EEXIST per folder so one existing folder does not stop the
        # remaining ones from being created.
        try:
            os.makedirs(os.path.join(root, m))
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise


def main():
    """Regenerate the bijectors, distributions and parameters ``__init__.py`` files."""
    root = flowtorch.__path__[0]

    # Create module folders if they don't exist
    _ensure_module_dirs(root, ["distributions", "bijectors", "parameters"])

    bijectors_init = os.path.join(root, "bijectors", "__init__.py")
    distributions_init = os.path.join(root, "distributions", "__init__.py")
    parameters_init = os.path.join(root, "parameters", "__init__.py")

    generate_imports_bijectors(bijectors_init)
    generate_imports_plain(distributions_init, list_distributions())
    generate_imports_plain(parameters_init, list_parameters())


if __name__ == "__main__":
    main()