def auto_import()

in tzrec/utils/load_class.py [0:0]


def auto_import(user_path=None):
    """Auto import python files.

    Importing the modules is what makes the ``register_xxx`` decorators in
    them take effect. The modules directly inside the pre-defined tzrec
    directories (``tzrec/models``, ``tzrec/datasets``, ``tzrec/features``)
    are always imported, non-recursively. If ``user_path`` is given, its
    directory and every sub-directory below it are imported as well.

    Args:
        user_path: directory, or a file inside a directory, that stores
            user-defined python code. A directory is used as-is; for a file
            its containing directory is used. When None, only the
            pre-defined directories are imported.
    """
    # Each entry is (directory, import_recursively).
    pre_defined_dirs = [
        ("tzrec/models", False),
        ("tzrec/datasets", False),
        ("tzrec/features", False),
    ]

    curr_dir, _ = os.path.split(__file__)
    # Directory that contains the ``tzrec`` package (e.g. site-packages when
    # installed). It is "" when __file__ is a relative "tzrec/utils/..." path.
    parent_dir = os.path.dirname(os.path.dirname(curr_dir))
    prefix_to_remove = None
    # Dealing with when in site-packages: make the pre-defined dirs point
    # under parent_dir, and pass that parent prefix to import_pkg
    # (presumably so it can strip it when deriving module names).
    if parent_dir != "":
        pre_defined_dirs = [
            (os.path.join(parent_dir, rel_dir), recursive)
            for rel_dir, recursive in pre_defined_dirs
        ]
        # os.path.join(..., "") appends exactly one platform separator,
        # matching how the directories above were joined; a hard-coded "/"
        # is wrong on Windows and doubled when parent_dir is "/".
        prefix_to_remove = os.path.join(parent_dir, "")

    if user_path is not None:
        if os.path.isdir(user_path):
            user_dir = user_path
        else:
            # NOTE(review): for a bare filename this is "", and os.walk("")
            # yields nothing, so sub-directories are not scanned in that case.
            user_dir = os.path.dirname(user_path)
        pre_defined_dirs.append((user_dir, True))

    for dir_path, recursive_import in pre_defined_dirs:
        # Top-level modules/packages of the directory itself.
        for pkg_info in pkgutil.iter_modules([dir_path]):
            import_pkg(pkg_info, prefix_to_remove)

        if recursive_import:
            # os.walk reports every nested directory exactly once (as a child
            # of its parent), so each sub-directory is scanned once here.
            for root, dirs, _ in os.walk(dir_path):
                for subdir in dirs:
                    sub_path = os.path.join(root, subdir)
                    for pkg_info in pkgutil.iter_modules([sub_path]):
                        import_pkg(pkg_info, prefix_to_remove)