# Copyright (c) Meta Platforms, Inc

"""
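Generates the `__init__.py` import files for the bijectors, distributions, and
parameters subpackages of FlowTorch.
This script assumes that you have installed FlowTorch with `setup.py develop`:
it writes the generated files into the installed package directory
(`flowtorch.__path__[0]`), which under a development install is the source tree.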

"""

import errno
import io
import os

import black
import flowtorch
import torch
from flowtorch.utils import (
    classname,
    copyright_header,
    list_bijectors,
    list_distributions,
    list_parameters,
)

# Re-bind the plain-text copyright header as a block of Python "#" comment
# lines, each terminated by a newline, ready to paste into generated files.
copyright_header = "".join(f"# {line}\n" for line in copyright_header.splitlines())

# Module docstring emitted near the top of every generated file, warning
# readers not to hand-edit it. Single-quoted triple quotes let the embedded
# double-quoted docstring delimiters appear without escaping.
autogen_msg = '''"""
Warning: This file was generated by flowtorch/scripts/generate_imports.py
Do not modify or delete!

"""
'''

# Non-FlowTorch imports needed by the generated bijectors module; emitted after
# the header/warning and before the per-class flowtorch imports.
bijectors_imports = "\n".join(
    [
        "import inspect",
        "from typing import cast, List, Tuple",
        "",
        "import torch",
    ]
)

# Source text appended verbatim to the end of the generated bijectors module,
# after the `standard_bijectors` and `meta_bijectors` lists it refers to. It
# defines the `isbijector` / `standard_bijector` helpers, computes
# `invertible_bijectors` (the standard bijectors whose `inverse` does not raise
# NotImplementedError on a random input of the forward shape), and builds
# `__all__` from the three list names plus every bijector name.
# Layout inside the string need not be black-clean: the whole generated file is
# passed through black before it is written.
# NOTE(review): `Bijector` is referenced here but not imported by
# `bijectors_imports`; it presumably arrives via the generated per-class
# flowtorch imports — confirm.
bijectors_code = """def isbijector(cls: type) -> bool:
    # A class must inherit from flowtorch.Bijector to be considered a valid bijector
    return issubclass(cls, Bijector)


def standard_bijector(cls: type) -> bool:
    # "Standard bijectors" are the ones we can perform standard automated tests upon
    return (
        inspect.isclass(cls)
        and isbijector(cls)
        and cls.__name__ not in [clx for clx, _ in meta_bijectors]
    )

# Determine invertible bijectors
invertible_bijectors = []
for bij_name, cls in standard_bijectors:
    # TODO: Use factored out version of the following
    # Define plan for flow
    event_dim = max(cls.domain.event_dim, 1)  # type: ignore
    event_shape = event_dim * [4]
    # base_dist = dist.Normal(torch.zeros(event_shape), torch.ones(event_shape))
    bij = cls(shape=torch.Size(event_shape))

    try:
        y = torch.randn(*bij.forward_shape(event_shape))
        bij.inverse(y)
    except NotImplementedError:
        pass
    else:
        invertible_bijectors.append((bij_name, cls))


__all__ = ["standard_bijectors", "meta_bijectors", "invertible_bijectors"] + [
    cls
    for cls, _ in cast(List[Tuple[str, Bijector]], meta_bijectors)
    + cast(List[Tuple[str, Bijector]], standard_bijectors)
]"""

# Settings passed to black.format_file_contents for every generated file:
# fast=False keeps black's safety checks enabled, and the default FileMode
# applies black's standard style.
fast: bool = False
mode = black.FileMode()


def generate_imports_plain(filename, classes):
    """Write a black-formatted module that re-exports ``classes``.

    The generated file contains the copyright header and autogen warning,
    one ``from <module> import <name>`` line per class, and an ``__all__``
    list naming every exported class.

    Args:
        filename: Path of the file to (over)write.
        classes: Iterable of ``(name, cls)`` pairs; ``name`` is the symbol
            imported from ``cls.__module__``.
    """
    # Sort classes by qualified name so the output does not depend on the
    # order in which ``classes`` arrives
    classes = sorted(classes, key=lambda tup: classname(tup[1]))

    with io.StringIO() as file:
        print(copyright_header, file=file)
        print(autogen_msg, file=file)
        for s, cls in classes:
            print(f"from {cls.__module__} import {s}", file=file)

        print("", file=file)

        # Indentation/trailing commas here are normalized by black below
        print("__all__ = [", file=file)
        all_list = ",\n\t".join([f'"{s}"' for s, _ in classes])
        print("\t", end="", file=file)
        print(all_list, file=file)
        print("]", end="", file=file)

        contents = file.getvalue()

    # Format BEFORE opening the output file: opening with "w" truncates it, so
    # a black failure must not leave an empty __init__.py behind.
    try:
        formatted = black.format_file_contents(contents, fast=fast, mode=mode)
    except black.NothingChanged:
        # black signals "already formatted" by raising rather than returning
        formatted = contents

    # Generated Python source is written as UTF-8 regardless of the locale
    with open(filename, "w", encoding="utf-8") as real_file:
        real_file.write(formatted)


def _split_bijectors(bijectors):
    """Partition ``(name, cls)`` bijector pairs into standard and meta lists.

    Standard bijectors can be initialized with a shape, whereas meta bijectors
    will throw a TypeError or require additional keyword arguments (e.g.,
    bij.Compose). A bijector whose construction raises ``TypeError`` or yields
    a ``flowtorch.Lazy`` is classified as meta; otherwise it is standard.

    Args:
        bijectors: Iterable of ``(name, cls)`` pairs.

    Returns:
        ``(standard_bijectors, meta_bijectors)``, each a list of the input
        pairs in their original order.
    """
    # TODO: Refactor this into flowtorch.utils.ismetabijector
    standard_bijectors = []
    meta_bijectors = []
    for b in bijectors:
        try:
            cls = b[1]
            x = cls(shape=torch.Size([2] * cls.domain.event_dim))
        except TypeError:
            meta_bijectors.append(b)
        else:
            if isinstance(x, flowtorch.Lazy):
                meta_bijectors.append(b)
            else:
                standard_bijectors.append(b)
    return standard_bijectors, meta_bijectors


def generate_imports_bijectors(filename):
    """Write the black-formatted bijectors import module to ``filename``.

    The generated file contains the copyright header and autogen warning,
    ``bijectors_imports``, one ``from <module> import <name>`` line per
    bijector, the ``standard_bijectors`` and ``meta_bijectors`` lists of
    ``("Name", Name)`` pairs, and finally ``bijectors_code``.

    Args:
        filename: Path of the file to (over)write.
    """
    standard_bijectors, meta_bijectors = _split_bijectors(list_bijectors())

    # Sort imports by qualified name so they do not depend on discovery order
    classes = sorted(
        standard_bijectors + meta_bijectors, key=lambda tup: classname(tup[1])
    )

    # Create lists of bijectors for each type, rendered as ("Name", Name)
    meta_str = ",\n    ".join([f'("{b[0]}", {b[0]})' for b in meta_bijectors])
    standard_str = ",\n    ".join(
        [f'("{b[0]}", {b[0]})' for b in standard_bijectors]
    )

    with io.StringIO() as file:
        # Copyright header and warning message
        print(copyright_header, file=file)
        print(autogen_msg, file=file)

        # Non-FlowTorch imports
        print(bijectors_imports, file=file)

        # FlowTorch imports
        for s, cls in classes:
            print(f"from {cls.__module__} import {s}", file=file)
        print("", file=file)

        print(
            f"""standard_bijectors = [
    {standard_str}
]
""",
            file=file,
        )

        print(
            f"""meta_bijectors = [
    {meta_str}
]
""",
            file=file,
        )

        # Rest of code
        print(bijectors_code, file=file)

        contents = file.getvalue()

    # Format BEFORE opening the output file: opening with "w" truncates it, so
    # a black failure must not leave an empty __init__.py behind.
    try:
        formatted = black.format_file_contents(contents, fast=fast, mode=mode)
    except black.NothingChanged:
        # black signals "already formatted" by raising rather than returning
        formatted = contents

    # Generated Python source is written as UTF-8 regardless of the locale
    with open(filename, "w", encoding="utf-8") as real_file:
        real_file.write(formatted)


if __name__ == "__main__":
    package_dir = flowtorch.__path__[0]

    # Create each module folder if it doesn't exist yet. exist_ok is applied
    # per folder so that one pre-existing folder does not abort the loop and
    # leave the remaining folders uncreated.
    for m in ["distributions", "bijectors", "parameters"]:
        os.makedirs(os.path.join(package_dir, m), exist_ok=True)

    bijectors_init = os.path.join(package_dir, "bijectors", "__init__.py")
    distributions_init = os.path.join(package_dir, "distributions", "__init__.py")
    parameters_init = os.path.join(package_dir, "parameters", "__init__.py")

    generate_imports_bijectors(bijectors_init)
    generate_imports_plain(distributions_init, list_distributions())
    generate_imports_plain(parameters_init, list_parameters())
