in scripts/generate_imports.py [0:0]
def generate_imports_bijectors(filename):
bij = list_bijectors()
meta_bijectors = []
standard_bijectors = []
# Standard bijectors can be initialized with a shape, whereas
# meta bijectors will throw a TypeError or require additional
# keyword arguments (e.g., bij.Compose)
# TODO: Refactor this into flowtorch.utils.ismetabijector
for b in bij:
try:
cls = b[1]
x = cls(shape=torch.Size([2] * cls.domain.event_dim))
except TypeError:
meta_bijectors.append(b)
else:
if isinstance(x, flowtorch.Lazy):
meta_bijectors.append(b)
else:
standard_bijectors.append(b)
with io.StringIO() as file:
# Sort classes by qualified name
classes = standard_bijectors + meta_bijectors
classes = sorted(classes, key=lambda tup: classname(tup[1]))
# Copyright header and warning message
print(copyright_header, file=file)
print(autogen_msg, file=file)
# Non-FlowTorch imports
print(bijectors_imports, file=file)
# FlowTorch imports
for s, cls in classes:
print(f"from {cls.__module__} import {s}", file=file)
print("", file=file)
# Create lists of bijectors for each type
meta_str = ",\n ".join([f'("{b[0]}", {b[0]})' for b in meta_bijectors])
standard_str = ",\n ".join(
[f'("{b[0]}", {b[0]})' for b in standard_bijectors]
)
print(
f"""standard_bijectors = [