flowtorch/utils.py (60 lines of code) (raw):
# Copyright (c) Meta Platforms, Inc
import importlib
import inspect
import os
import pkgutil
from functools import partial
from typing import Any, Callable, Optional, Sequence, Tuple
import flowtorch
from flowtorch.bijectors.base import Bijector
from flowtorch.parameters.base import Parameters
from torch.distributions import Distribution
# Copyright notice text kept as a module-level constant.
copyright_header = "Copyright (c) Meta Platforms, Inc"
def classname(cls: type) -> str:
    """
    Return the dotted name of a class: its defining module, a period, then
    the class name (e.g. ``builtins.int``).
    """
    module_part = cls.__module__
    name_part = cls.__name__
    return module_part + "." + name_part
def issubclass_byname(cls: type, test_cls: type) -> bool:
    """
    Decide whether ``cls`` derives from ``test_cls`` by comparing dotted class
    names along the method resolution order of ``cls``, rather than comparing
    the class objects themselves as the built-in issubclass does.
    """
    wanted = classname(test_cls)
    return any(classname(ancestor) == wanted for ancestor in cls.__mro__)
def isderivedclass(cls: type, base_cls: type) -> bool:
    """
    Return True when ``cls`` is a class whose method resolution order contains
    a class with the same dotted name as ``base_cls``; False otherwise
    (including when ``cls`` is not a class at all).
    """
    # Anything that is not a class can never be a derived class.
    if not inspect.isclass(cls):
        return False

    # NOTE issubclass won't always do what we want here if base_cls is imported
    # inside the module of cls. I.e. issubclass returns False if cls inherits
    # from a base_cls with a different instance. Hence the comparison by name.
    return issubclass_byname(cls, base_cls)
def list_bijectors() -> Sequence[Tuple[str, Bijector]]:
    """
    List the ``(member name, class)`` pairs found under the ``bijectors``
    package whose class derives (by name) from ``Bijector``.

    Classes whose defining module path contains ``.ops.`` are left out, and
    each class appears once, keyed by its dotted class name.
    """
    is_bijector = partial(isderivedclass, base_cls=Bijector)

    unique = {}
    for entry in _walk_packages("bijectors", is_bijector):
        # entry[1] is the class object; skip those defined in an "ops" module
        if ".ops." in entry[1].__module__:
            continue
        # A repeated class keeps its first position but the latest pair
        unique[classname(entry[1])] = entry

    return list(unique.values())
def list_parameters() -> Sequence[Tuple[str, Parameters]]:
    """
    List the ``(member name, class)`` pairs found under the ``parameters``
    package whose class derives (by name) from ``Parameters``, with each class
    appearing once, keyed by its dotted class name.
    """
    is_parameters = partial(isderivedclass, base_cls=Parameters)

    unique = {}
    for entry in _walk_packages("parameters", is_parameters):
        # A repeated class keeps its first position but the latest pair
        unique[classname(entry[1])] = entry

    return list(unique.values())
def list_distributions() -> Sequence[Tuple[str, Distribution]]:
    """
    List the ``(member name, class)`` pairs found under the ``distributions``
    package whose class derives (by name) from ``torch.distributions``'
    ``Distribution``, with each class appearing once, keyed by its dotted
    class name.

    The return annotation names ``Distribution`` because that is the base
    class the filter selects on.
    """
    ans = _walk_packages(
        "distributions", partial(isderivedclass, base_cls=Distribution)
    )
    # Deduplicate on the dotted class name; insertion order is preserved
    return list({classname(cls[1]): cls for cls in ans}.values())
def _walk_packages(
    modname: str, filter: Optional[Callable[[Any], bool]]
) -> Sequence[Tuple[str, Any]]:
    """
    Collect members from every module found under ``flowtorch.<modname>``.

    Args:
        modname: name of the sub-package of ``flowtorch`` to walk, e.g.
            ``"bijectors"``.
        filter: predicate handed to ``inspect.getmembers``; only members for
            which it returns a true value are kept. ``None`` keeps them all.

    Returns:
        The concatenated ``(name, value)`` pairs reported by
        ``inspect.getmembers`` for each discovered module. The same object can
        appear more than once if several modules expose it.

    Raises:
        ImportError: if a discovered module cannot be imported. Any exception
            raised while executing a module's code propagates unchanged.
    """
    classes = []

    # NOTE: I use path of flowtorch rather than e.g. flowtorch.bijectors
    # to avoid circular imports
    path = [os.path.join(flowtorch.__path__[0], modname)]  # type: ignore

    # The following line uncovered a bug that hasn't been fixed in mypy:
    # https://github.com/python/mypy/issues/1422
    for _, this_modname, _ in pkgutil.walk_packages(
        path=path,  # type: ignore  # mypy issue #1422
        prefix=f"{flowtorch.__name__}.{modname}.",
        onerror=lambda x: None,
    ):
        # Import through the regular import system. The finder's
        # find_module() was removed in Python 3.12 and loader.load_module()
        # is deprecated; import_module() also reuses modules that are already
        # in sys.modules instead of executing them a second time.
        module = importlib.import_module(this_modname)
        classes.extend(inspect.getmembers(module, filter))

    return classes
class InterfaceError(Exception):
    """Exception subclass that adds no behaviour of its own."""