tzrec/utils/load_class.py (95 lines of code) (raw):

# Copyright (c) 2024, Alibaba Group;
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#    http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import os
import pkgutil
import pydoc
import re
import sys
import traceback
from abc import ABCMeta

# Matches ``lambda`` only as a standalone keyword, so dotted paths whose
# components merely contain the substring (e.g. ``pkg.lambda_layer.Foo``)
# are still resolved with pydoc instead of being eval'd.
_LAMBDA_RE = re.compile(r"\blambda\b")


def _exec_module_from_spec(mod_name, spec):
    """Execute the module described by ``spec`` and register it as ``mod_name``.

    Args:
        mod_name: top-level name to register the module under in sys.modules.
        spec: an importlib ModuleSpec with a non-None loader.

    Return:
        the loaded module object.
    """
    existing = sys.modules.get(mod_name)
    existing_spec = getattr(existing, "__spec__", None)
    if existing_spec is not None and existing_spec.origin == spec.origin:
        # The same file is already imported. Executing it again would re-run
        # the register_xxx decorators / registering metaclasses and trip the
        # duplicate-registration check, so behave like ``import`` (no-op).
        return existing
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as the import system does, so that the
    # module can refer to itself through sys.modules while it runs.
    sys.modules[mod_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module behind.
        if existing is not None:
            sys.modules[mod_name] = existing
        else:
            sys.modules.pop(mod_name, None)
        raise
    return module


def import_pkg(pkg_info, prefix_to_remove=None):
    """Import package.

    Modules whose name ends with ``_test`` are never executed.

    Args:
        pkg_info: pkgutil.ModuleInfo object (or an equivalent
            ``(module_finder, name, ...)`` tuple).
        prefix_to_remove: the package path prefix to be removed before
            deciding between file-based and dotted-name import.

    Raises:
        Exception: if a module under an absolute path cannot be located.
        ValueError: if importing a module by dotted name fails.
    """
    finder, mod_name = pkg_info[0], pkg_info[1]
    package_path = finder.path
    # Strip only a *leading* prefix; str.replace would also remove
    # occurrences of the prefix in the middle of the path.
    if prefix_to_remove is not None and package_path.startswith(prefix_to_remove):
        package_path = package_path[len(prefix_to_remove):]

    if package_path.startswith("/"):
        # Absolute path that is not under the project root: load the file
        # directly via its finder. (find_module/load_module were removed
        # from FileFinder in Python 3.12, so use the spec-based API.)
        spec = finder.find_spec(mod_name)
        if spec is None or spec.loader is None:
            raise Exception(
                "import module %s failed" % os.path.join(package_path, mod_name)
            )
        # skip those test files
        if not mod_name.endswith("_test"):
            _exec_module_from_spec(mod_name, spec)
    else:
        # Path relative to the project root: use similar import methods as the
        # import keyword, deriving the dotted module path from the file path.
        module_path = os.path.join(package_path, mod_name).replace("/", ".")
        # skip those test files
        if not mod_name.endswith("_test"):
            try:
                __import__(module_path)
            except Exception as e:
                raise ValueError(
                    "import module %s failed: %s" % (module_path, str(e))
                ) from e


def auto_import(user_path=None):
    """Auto import python files. So that register_xxx decorator will take effect.

    By default, we will import files in pre-defined directory and import all
    files recursively in user_dir.

    Args:
        user_path: directory or file that store user-defined python code,
            by default we will only search file in current directory
    """
    # (directory, import recursively or not)
    pre_defined_dirs = [
        ("tzrec/models", False),
        ("tzrec/datasets", False),
        ("tzrec/features", False),
    ]
    # This file lives in <root>/tzrec/utils/, so go two levels above its
    # directory to reach <root>.
    curr_dir = os.path.dirname(__file__)
    parent_dir = os.path.dirname(os.path.dirname(curr_dir))
    prefix_to_remove = None
    # dealing with when in site-packages, remove parent directory prefix
    if parent_dir != "":
        pre_defined_dirs = [
            (os.path.join(parent_dir, rel_dir), recursive)
            for rel_dir, recursive in pre_defined_dirs
        ]
        prefix_to_remove = parent_dir + "/"

    if user_path is not None:
        if os.path.isdir(user_path):
            user_dir = user_path
        else:
            user_dir = os.path.dirname(user_path)
        pre_defined_dirs.append((user_dir, True))

    for dir_path, recursive_import in pre_defined_dirs:
        for pkg_info in pkgutil.iter_modules([dir_path]):
            import_pkg(pkg_info, prefix_to_remove)
        if recursive_import:
            for root, dirs, _ in os.walk(dir_path):
                for subdir in dirs:
                    dirname = os.path.join(root, subdir)
                    for pkg_info in pkgutil.iter_modules([dirname]):
                        import_pkg(pkg_info, prefix_to_remove)


def register_class(class_map, class_name, cls):
    """Register a class into class_map.

    Re-registering the identical class under the same name is allowed.

    Args:
        class_map: class register map.
        class_name: name of the class.
        cls: a class.

    Raises:
        AssertionError: if class_name is already registered to a different
            class. NOTE(review): this check disappears under ``python -O``.
    """
    assert class_name not in class_map or class_map[class_name] == cls, (
        f"confilict class {cls} , "
        f"{class_name} is already register to be {class_map[class_name]}"
    )
    class_map[class_name] = cls


def get_register_class_meta(class_map):
    """Get a meta class with registry.

    Every class created with the returned metaclass (including subclasses) is
    registered into ``class_map`` under its class name and gains a
    ``create_class(name)`` classmethod that looks classes up in that map.

    Args:
        class_map: class register map.

    Return:
        a meta class with registry.
    """

    class RegisterABCMeta(ABCMeta):
        def __new__(mcs, name, bases, attrs, **kwargs):
            # Forward class keyword arguments (e.g. for __init_subclass__).
            newclass = super().__new__(mcs, name, bases, attrs, **kwargs)
            register_class(class_map, name, newclass)

            @classmethod
            def create_class(cls, name):
                if name in class_map:
                    return class_map[name]
                else:
                    raise Exception(
                        "Class %s is not registered. Available ones are %s"
                        % (name, list(class_map.keys()))
                    )

            newclass.create_class = create_class
            return newclass

    return RegisterABCMeta


def load_by_path(path):
    """Load functions or modules or classes.

    Args:
        path: path to modules or functions or classes,
            such as: torch.nn.ReLU; a leading ``nn.`` is expanded to
            ``torch.nn.``. A lambda expression string is evaluated.

    Return:
        modules or functions or classes, or None if path is None/empty,
        cannot be located, or fails during import.
    """
    # Check for None before calling str methods on it.
    if path is None:
        return None
    path = path.strip()
    if path == "":
        return None
    if _LAMBDA_RE.search(path):
        # SECURITY: eval() executes arbitrary code. Only pass trusted
        # (e.g. config-authored) strings to this function.
        return eval(path)
    components = path.split(".")
    if components[0] == "nn":
        components.insert(0, "torch")
        path = ".".join(components)
    try:
        return pydoc.locate(path)
    except pydoc.ErrorDuringImport:
        print("load %s failed: %s" % (path, traceback.format_exc()))
        return None