scripts/generate_imports.py (100 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc
"""
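Generates the ``__init__.py`` import files for the bijectors, distributions,
and parameters subpackages of flowtorch.
Because it writes into the imported package's own directory, this script
assumes flowtorch was installed in development mode (`python setup.py develop`).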
"""
import errno
import io
import os
import black
import flowtorch
import torch
from flowtorch.utils import (
classname,
copyright_header,
list_bijectors,
list_distributions,
list_parameters,
)
# Convert the raw license text into Python comment lines ("# <line>\n") so it
# can be printed verbatim at the top of every generated file.
copyright_header = "".join(
    f"# {line}\n" for line in copyright_header.splitlines()
)
# Module docstring emitted into every generated file, telling readers the file
# is produced by this script and must not be edited by hand.
autogen_msg = (
    '"""\n'
    "Warning: This file was generated by flowtorch/scripts/generate_imports.py\n"
    "Do not modify or delete!\n"
    '"""\n'
)

# Non-FlowTorch imports needed by the generated bijectors/__init__.py
# (used by the code in ``bijectors_code``).
bijectors_imports = "\n".join(
    [
        "import inspect",
        "from typing import cast, List, Tuple",
        "import torch",
    ]
)
# Python source appended verbatim to the generated bijectors/__init__.py, after
# the imports and the ``standard_bijectors`` / ``meta_bijectors`` lists.
# It relies on names the generated module defines earlier:
#   - ``inspect``, ``cast``, ``List``, ``Tuple`` and ``torch``, from
#     ``bijectors_imports``;
#   - ``standard_bijectors`` and ``meta_bijectors``;
#   - ``Bijector``. NOTE(review): presumably this comes from the per-class
#     imports; confirm that list_bijectors() yields it.
# The snippet must be correctly indented Python. The whole file is passed
# through black, which refuses input it cannot parse.
bijectors_code = """def isbijector(cls: type) -> bool:
    # A class must inherit from flowtorch.Bijector to be considered a valid bijector
    return issubclass(cls, Bijector)


def standard_bijector(cls: type) -> bool:
    # "Standard bijectors" are the ones we can perform standard automated tests upon
    return (
        inspect.isclass(cls)
        and isbijector(cls)
        and cls.__name__ not in [clx for clx, _ in meta_bijectors]
    )


# Determine invertible bijectors
invertible_bijectors = []
for bij_name, cls in standard_bijectors:
    # TODO: Use factored out version of the following
    # Define plan for flow
    event_dim = max(cls.domain.event_dim, 1)  # type: ignore
    event_shape = event_dim * [4]
    # base_dist = dist.Normal(torch.zeros(event_shape), torch.ones(event_shape))
    bij = cls(shape=torch.Size(event_shape))

    try:
        y = torch.randn(*bij.forward_shape(event_shape))
        bij.inverse(y)
    except NotImplementedError:
        pass
    else:
        invertible_bijectors.append((bij_name, cls))


__all__ = ["standard_bijectors", "meta_bijectors", "invertible_bijectors"] + [
    cls
    for cls, _ in cast(List[Tuple[str, Bijector]], meta_bijectors)
    + cast(List[Tuple[str, Bijector]], standard_bijectors)
]"""
# Black settings shared by every generated file: the default formatting mode,
# and fast=False, which keeps black's equivalence/stability safety checks on.
mode, fast = black.FileMode(), False
def generate_imports_plain(filename, classes):
    """Write a black-formatted ``__init__.py`` that re-exports ``classes``.

    The generated file contains the copyright header, the auto-generation
    warning, one ``from <module> import <name>`` line per class, and an
    ``__all__`` list of the exported names.

    Args:
        filename: Path of the file to (over)write.
        classes: Iterable of ``(name, cls)`` pairs. ``name`` is the exported
            identifier; ``cls.__module__`` is the module it is imported from.
    """
    # Sort classes by qualified name so the output is deterministic
    classes = sorted(classes, key=lambda tup: classname(tup[1]))

    with io.StringIO() as file:
        print(copyright_header, file=file)
        print(autogen_msg, file=file)
        for s, cls in classes:
            print(f"from {cls.__module__} import {s}", file=file)
        print("", file=file)
        print("__all__ = [", file=file)
        all_list = ",\n\t".join([f'"{s}"' for s, _ in classes])
        print("\t", end="", file=file)
        print(all_list, file=file)
        print("]", end="", file=file)
        contents = file.getvalue()

    # Format before opening the target. open(..., "w") truncates at once, so
    # a black error raised after opening would leave an empty __init__.py.
    try:
        formatted = black.format_file_contents(contents, fast=fast, mode=mode)
    except black.NothingChanged:
        # Raised when the input is already black-formatted; not an error.
        formatted = contents

    with open(filename, "w", encoding="utf-8") as real_file:
        real_file.write(formatted)
def _is_meta_bijector(cls):
    """Return True if bijector class ``cls`` is a "meta" bijector.

    Standard bijectors can be instantiated from a shape alone. A class counts
    as meta when constructing it with only ``shape=...`` raises ``TypeError``
    (e.g. it needs extra keyword arguments, like ``Compose``), or when the
    construction yields a ``flowtorch.Lazy`` instance.
    """
    # TODO: Refactor this into flowtorch.utils.ismetabijector
    try:
        x = cls(shape=torch.Size([2] * cls.domain.event_dim))
    except TypeError:
        return True
    return isinstance(x, flowtorch.Lazy)


def generate_imports_bijectors(filename):
    """Write the black-formatted ``bijectors/__init__.py``.

    Bijectors from ``list_bijectors()`` are split into ``standard_bijectors``
    and ``meta_bijectors``. The generated file contains:

    - the copyright header and the auto-generation warning;
    - the fixed non-FlowTorch imports;
    - one import per bijector;
    - both lists;
    - the code in ``bijectors_code``.

    Args:
        filename: Path of the file to (over)write.
    """
    meta_bijectors = []
    standard_bijectors = []
    for b in list_bijectors():
        if _is_meta_bijector(b[1]):
            meta_bijectors.append(b)
        else:
            standard_bijectors.append(b)

    # Sort classes by qualified name so the import lines are deterministic
    classes = sorted(
        standard_bijectors + meta_bijectors, key=lambda tup: classname(tup[1])
    )

    # Source text for each list; every entry is a ("Name", Name) pair
    meta_str = ",\n    ".join([f'("{b[0]}", {b[0]})' for b in meta_bijectors])
    standard_str = ",\n    ".join(
        [f'("{b[0]}", {b[0]})' for b in standard_bijectors]
    )

    with io.StringIO() as file:
        # Copyright header and warning message
        print(copyright_header, file=file)
        print(autogen_msg, file=file)
        # Non-FlowTorch imports
        print(bijectors_imports, file=file)
        # FlowTorch imports
        for s, cls in classes:
            print(f"from {cls.__module__} import {s}", file=file)
        print("", file=file)
        # Lists of bijectors for each type
        print(f"standard_bijectors = [\n    {standard_str}\n]\n", file=file)
        print(f"meta_bijectors = [\n    {meta_str}\n]\n", file=file)
        # Rest of code
        print(bijectors_code, file=file)
        contents = file.getvalue()

    # Format before opening the target. open(..., "w") truncates at once, so
    # a black error raised after opening would leave an empty __init__.py.
    try:
        formatted = black.format_file_contents(contents, fast=fast, mode=mode)
    except black.NothingChanged:
        # Raised when the input is already black-formatted; not an error.
        formatted = contents

    with open(filename, "w", encoding="utf-8") as real_file:
        real_file.write(formatted)
if __name__ == "__main__":
    package_dir = flowtorch.__path__[0]

    # Create module folders if they don't exist. Use exist_ok per folder so
    # an already-existing folder never prevents creating the others.
    for m in ["distributions", "bijectors", "parameters"]:
        os.makedirs(os.path.join(package_dir, m), exist_ok=True)

    bijectors_init = os.path.join(package_dir, "bijectors", "__init__.py")
    distributions_init = os.path.join(package_dir, "distributions", "__init__.py")
    parameters_init = os.path.join(package_dir, "parameters", "__init__.py")

    generate_imports_bijectors(bijectors_init)
    generate_imports_plain(distributions_init, list_distributions())
    generate_imports_plain(parameters_init, list_parameters())