tinynn/util/util.py (126 lines of code) (raw):
import importlib.util
import logging
import os
import sys
import types
import typing
from functools import wraps
# Module-level registry of loggers: `get_logger` appends each logger it hands
# out, and `set_global_log_level` iterates this list to retune them all.
loggers = []
class LazyExpression:
    """Wraps a zero-argument callable so that its value is computed on demand"""

    expr: typing.Callable[[], typing.Any]

    def __init__(self, expr: typing.Callable[[], typing.Any]):
        """Stores the callable without invoking it

        Args:
            expr (typing.Callable[[], typing.Any]): the deferred expression
        """

        self.expr = expr

    def eval(self) -> typing.Any:
        """Invokes the stored callable

        The callable is invoked on every call; nothing is cached.

        Returns:
            typing.Any: whatever the callable returns
        """

        deferred = self.expr
        return deferred()
class LazyObject(object):
    """An object that can be constructed lazily using lazy expressions

    The constructor arguments may themselves be `LazyObject` or `LazyExpression`
    instances; they are resolved to concrete values each time `get_next` builds
    a fresh instance of `cls`.
    """

    cls: type
    pos_options: list
    kw_options: typing.Dict[str, typing.Any]
    current_obj: typing.Any

    def __init__(
        self,
        cls: type,
        pos_options: typing.Optional[list] = None,
        kw_options: typing.Optional[typing.Dict[str, typing.Any]] = None,
    ):
        """Constructs a new lazy object

        Args:
            cls (type): The type of the class to be constructed lazily
            pos_options (typing.Optional[list], optional): The positional options. Defaults to None.
            kw_options (typing.Optional[typing.Dict[str, typing.Any]], optional): The keyword options. Defaults to None.
        """

        self.cls = cls
        self.pos_options = [] if pos_options is None else pos_options
        self.kw_options = {} if kw_options is None else kw_options

        # Build the first instance eagerly so that `get_current` is usable right away
        self.get_next()

    def __str__(self) -> str:
        """Proxy __str__ to the underlying object

        Returns:
            str: String representation of the underlying object
        """

        return str(self.current_obj)

    def __repr__(self) -> str:
        """Proxy __repr__ to the underlying object

        Returns:
            str: String representation of the underlying object
        """

        return repr(self.current_obj)

    def get_current(self) -> typing.Any:
        """Gets the current copy of the lazily-constructed object

        Returns:
            typing.Any: The current copy of the lazily-constructed object
        """

        return self.current_obj

    @staticmethod
    def _resolve(opt: typing.Any) -> typing.Any:
        """Turns a single option into the concrete value passed to the constructor

        Args:
            opt (typing.Any): A plain value, a `LazyObject` or a `LazyExpression`

        Returns:
            typing.Any: The current object of a `LazyObject`, the evaluated value of
                a `LazyExpression`, or `opt` itself for anything else
        """

        if isinstance(opt, LazyObject):
            return opt.get_current()
        if isinstance(opt, LazyExpression):
            return opt.eval()
        return opt

    def get_next(self):
        """Create a new object with the class and the arguments

        Positional and keyword options are resolved independently, so a lazy
        keyword option yields its own value rather than that of a positional one.

        Returns:
            The newly constructed object, which also becomes the current object
        """

        pos_options = [self._resolve(opt) for opt in self.pos_options]
        kw_options = {opt_k: self._resolve(opt_v) for opt_k, opt_v in self.kw_options.items()}

        self.current_obj = self.cls(*pos_options, **kw_options)
        return self.current_obj
def get_actual_type(param_type: type) -> typing.List[type]:
    """Expands a (possibly generic) type annotation into its candidate runtime types

    Args:
        param_type (type): The annotation to inspect

    Returns:
        typing.List[type]: `[param_type]` when it has no `__origin__`; the flattened
            candidates of every member when it is a `typing.Union`; otherwise a
            single-element list holding its `__origin__`
    """

    # Plain classes carry no `__origin__` and are their own candidate
    if not hasattr(param_type, '__origin__'):
        return [param_type]

    origin = param_type.__origin__
    if origin is not typing.Union:
        return [origin]

    # Unions are flattened recursively, keeping the declaration order of the members
    candidates = []
    for member in param_type.__args__:
        candidates += get_actual_type(member)
    return candidates
def conditional(cond: typing.Callable[[], bool]) -> typing.Callable:
    """Builds a decorator that lets the decorated function run only while `cond()` is truthy

    Args:
        cond (typing.Callable[[], bool]): Zero-argument predicate, checked on every call

    Returns:
        typing.Callable: A decorator; the decorated function returns None without
            running whenever the predicate is falsy
    """

    def conditional_decorator(f):
        @wraps(f)
        def guarded(*args, **kwds):
            # Skip the call entirely (implicitly yielding None) when the predicate fails
            if not cond():
                return None
            return f(*args, **kwds)

        return guarded

    return conditional_decorator
def class_conditional(cond: typing.Callable[[typing.Any], bool], default_val=None) -> typing.Callable:
    """Builds a method decorator that runs the method only when `cond(self)` is truthy

    Args:
        cond (typing.Callable[[typing.Any], bool]): Predicate that receives the first
            positional argument of the call (the instance for a bound method)
        default_val: The value returned when the predicate is falsy. Defaults to None.

    Returns:
        typing.Callable: A decorator producing the guarded method
    """

    def conditional_decorator(f):
        @wraps(f)
        def guarded(*args, **kwds):
            # args[0] is the instance the method was invoked on
            if not cond(args[0]):
                return default_val
            return f(*args, **kwds)

        return guarded

    return conditional_decorator
def tensors2ndarray(tensors) -> list:
    """Convert tensors in arbitrary format into list of ndarray

    The input is walked recursively: lists/tuples element by element, dicts by
    value, and any other non-primitive object through its public attributes
    (those whose name does not start with a double underscore). Anything that
    exposes both `detach` and `numpy` is treated as a tensor and converted via
    `detach().numpy()`.

    NOTE: there is no cycle detection, so an object graph that refers back to
    itself will recurse without bound.

    Args:
        tensors: tensors in arbitrary format

    Returns:
        typing.List[numpy.ndarray]: list of ndarray
    """

    new_tensors = []
    if isinstance(tensors, (list, tuple)):
        for item in tensors:
            new_tensors.extend(tensors2ndarray(item))
    elif isinstance(tensors, dict):
        for value in tensors.values():
            new_tensors.extend(tensors2ndarray(value))
    elif hasattr(tensors, 'detach') and hasattr(tensors, 'numpy'):
        new_tensors.append(tensors.detach().numpy())
    elif not isinstance(tensors, (bool, str, int, float, types.FunctionType)):
        for k in tensors.__dir__():
            if not k.startswith('__'):
                # Look the attribute up by its *name*; the value is what gets recursed into
                v = getattr(tensors, k)
                new_tensors.extend(tensors2ndarray(v))
    return new_tensors
def import_from(module: str, name: str):
    """Fetch one attribute from a module by name (Equivalent of `from module import name`)

    Args:
        module (str): The dotted name of the module to import
        name (str): The attribute to pull out of that module

    Returns:
        The imported class, function, etc.
    """

    # A non-empty `fromlist` makes `__import__` return the leaf module of a dotted
    # path instead of the top-level package
    return getattr(__import__(module, fromlist=[name]), name)
def import_from_path(module: str, path: str, name: str):
    """Import a name from the source file at the given path (Equivalent of `from module import name`)

    The freshly created module is registered in `sys.modules` under `module`
    *before* its code is executed, so that code which looks itself up through
    `sys.modules[__name__]` during import works. If execution fails, the
    previous `sys.modules` entry (or its absence) is restored and the exception
    propagates.

    Args:
        module (str): The namespace of the module
        path (str): The path of the file
        name (str): The name to be imported

    Returns:
        The imported class, function, etc.
    """

    spec = importlib.util.spec_from_file_location(module, path)
    foo = importlib.util.module_from_spec(spec)

    # Remember what was registered under this name so a failed import leaves no trace
    missing = object()
    previous = sys.modules.get(module, missing)

    sys.modules[module] = foo
    try:
        spec.loader.exec_module(foo)
    except BaseException:
        if previous is missing:
            sys.modules.pop(module, None)
        else:
            sys.modules[module] = previous
        raise

    return getattr(foo, name)
def get_logger(name: str, level: typing.Optional[str] = None) -> logging.Logger:
    """Acquires the logger for a module with the given name

    The logger is recorded in the module-level `loggers` registry (once per
    logger object) so that `set_global_log_level` can adjust it later.

    Args:
        name (str): The name of the module
        level (typing.Optional[str]): The level of the logger, e.g. `INFO`, `DEBUG` or `ERROR`.
            Defaults to None, which means `INFO`. When the `LOGLEVEL` environment
            variable is set, its value takes precedence over this argument.

    Returns:
        logging.Logger: The logger of the module
    """

    if level is None:
        level = 'INFO'

    # The environment variable, when present, overrides the requested level
    level = os.environ.get('LOGLEVEL', level)

    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Install the shared message format on the root logger
    logging.basicConfig(format='%(levelname)s (%(name)s) %(message)s')

    # `logging.getLogger` hands back the same object for a repeated name, so only
    # register it the first time to keep the registry free of duplicates
    if not any(known is logger for known in loggers):
        loggers.append(logger)

    return logger
def set_global_log_level(level: str = "INFO"):
    """Applies one log level to every logger recorded in the `loggers` registry

    Args:
        level (str): The level to apply. Defaults to "INFO".
    """

    for registered in loggers:
        registered.setLevel(level)