tinynn/graph/tracer.py (2,599 lines of code) (raw):

import contextlib
import copy
import ctypes
import importlib
import inspect
import io
import os
import queue
import re
import sys
import traceback
import typing
import weakref
import types

import torch
import torch.nn as nn
import yaml
import numpy as np

from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel

from tinynn.util.train_util import get_module_device
from tinynn.util.util import get_logger, import_from_path, tensors2ndarray

from ._utils import patch_getitem, revert_getitem, patch_new, revert_new
from . import interop  # noqa: F401


# Basic types
class GlobalData(object):
    """A mutable wrapper around a value that is shared as a module-level global in this script.

    Wrapping the value lets functions update the shared state through `set_value` (or by calling
    the wrapper with one argument) without rebinding a module global.
    """

    def __init__(self, value):
        super().__init__()
        self.value = value

    def get_value(self):
        """Returns the inner value of the wrapper"""
        return self.value

    def set_value(self, value):
        """Sets the inner value of the wrapper"""
        self.value = value

    def __str__(self):
        """Returns the string representation of the inner object"""
        return str(self.value)

    def __repr__(self):
        """Returns the `repr` of the inner object"""
        return repr(self.value)

    def __call__(self, *args):
        """Shorthand accessor for the wrapped value.

        e.g.
            a = GlobalData(3)
            a()   -> a.get_value()
            a(1)  -> a.set_value(1)

        Raises:
            ValueError: If more than one argument is passed
        """
        if len(args) == 0:
            return self.get_value()
        elif len(args) == 1:
            return self.set_value(*args)
        else:
            raise ValueError(
                f'length of input data must in [0,1], but got length: {len(args)} --> args: {args}'
            )

    def __bool__(self):
        """Returns the truth value of the inner object.

        `bool()` is used instead of `value.__bool__()` because many built-in types (e.g. `list`,
        `dict`, `str`) only define `__len__` and have no `__bool__` method.
        """
        return bool(self.value)


# Constants

# Template for a traced module
# NOTE(review): the substituted blocks are expected to carry their own indentation — confirm against
# the code generator that fills this template.
MODULE_TEMPLATE = """
%(import_block)s

class %(name_block)s(torch.nn.Module):
    def __init__(self):
        super().__init__()

%(init_block)s

%(forward_block)s

if __name__ == "__main__":
    model = %(name_block)s()
%(load_weight_block)s

    model.eval()
    model.cpu()

%(input_block)s

    output = model(%(input_names)s)
    print(output)
"""

# Special math operators
SPECIAL_OPERATORS = ['add', 'and', 'div', 'floordiv', 'lshift', 'mul', 'or', 'pow', 'rshift', 'sub', 'xor', 'truediv']

# Global objects

# Logger
log = get_logger(__name__, 'WARNING')

# Loaded overriable items from the config file
overridable_funcs = {}
overridable_modules = []
overridable_creation_funcs = {}
torch_overrides_funcs = []
torch_overrides_wrappers = []
torch_tracking_modules = []

# Reuse generated wrapper functions and modules
generated_wrapper_funcs = {}
generated_wrapper_modules = {}

# Load state for the override items
overridable_funcs_loaded = GlobalData(False)
overridable_modules_loaded = GlobalData(False)
overridable_creation_funcs_loaded = GlobalData(False)
torch_overrides_funcs_loaded = GlobalData(False)
tracking_modules_loaded = GlobalData(False)

funcs_overrided = GlobalData(False)
modules_overrided = GlobalData(False)
creation_funcs_overrided = GlobalData(False)
tracking_modules_overrided = GlobalData(False)

# Lock for tracing
lock = GlobalData(False)
handle_func_lock = GlobalData(False)

# Whether the constructors get traced
module_constructor_traced = set()

# Current traced graph
current_graph = GlobalData(None)

# Generated module constructor lines
module_constructor_lines = {}
module_constructor_weakrefs = {}

# Directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))

# Original module constructor signatures
module_constructor_signatures = {}

# Original values of tracked objects
original_values_for_tracked_objects = {}

# Original module class names
importable_module_names = {}

# Ignore warning for update module parameters
mod_param_update_warning_ignore = GlobalData(False)

# Modules that are skipped while tracing
skip_modules = set()


class TraceNode(object):
    """A basic data structure to represent a node in the computation graph"""

    module: typing.Union[torch.nn.Module, 'TraceFunction', 'ConstantNode']
    prev_nodes: typing.List['TraceNode']
    next_nodes: typing.List['TraceNode']
    prev_tensors: typing.List[torch.Tensor]
    next_tensors: typing.List[torch.Tensor]
    prev_indices: typing.List[typing.Optional[int]]
    rev_index: bool
    unique_name: str
    active: bool

    def __init__(
        self,
        module: typing.Union[torch.nn.Module, 'ConstantNode', 'TraceFunction'],
        cur_graph: typing.Optional['TraceGraph'] = None,
    ):
        """Creates a node wrapping `module` and registers its (numbered) name in the graph.

        Args:
            module: The wrapped object, either an `nn.Module`, a `ConstantNode` or a `TraceFunction`
            cur_graph: The graph the node belongs to. Defaults to `current_graph()` when omitted.
        """
        # Inner module, could either be a `nn.Module`, `ConstantNode` or `TraceFunction`
        self.module = module

        # Previous and next nodes in the computation graph
        self.prev_nodes = []
        self.next_nodes = []

        # The input and output tensors for the node
        self.prev_tensors = []
        self.next_tensors = []

        # The indices used to retrieve the corresponding tensor
        # e.g. torch.chunk() returns [A, B], in which A and B are PyTorch tensors.
        # so if we use A in this node, then the corresponding prev_index is 0.
        # If the tensor is not a sub item, then `None` should be used.
        self.prev_indices = []

        # In some nodes, the indices are reversed. For example, for an output node,
        # the indices are not used to fetch the items, but to construct a list that
        # contains them.
        self.rev_index = False

        # The current TraceGraph to be processed
        # In the trace phase, it can be obtained through `current_graph()`
        # Otherwise, you need to pass it explicitly
        if cur_graph is None:
            cur_graph = current_graph()

        # Unique name of the node (the key of the node in the node map in TraceGraph)
        if type(module) in (ConstantNode, TraceFunction):
            self.unique_name = module.unique_name
        else:
            self.unique_name = cur_graph.module_unique_name_dict[id(module)]

        if isinstance(module, nn.Module) and id(module) in cur_graph.module_original_name_dict:
            self.original_name = cur_graph.module_original_name_dict[id(module)]
        elif type(module) is ConstantNode:
            self.original_name = module.original_name
        else:
            self.original_name = self.unique_name

        # Whether the node is active in the computation graph
        self.active = True

        # The index of the node in the graph
        self.forward_order = 0

        # Whether the node is in a quantized graph
        self.quantized = False

        # Numbering of the name of the node
        # (the first occurrence keeps the bare name, later ones get a `_<count>` suffix)
        if cur_graph.global_nodes.get(self.unique_name) is not None:
            cur_graph.global_nodes[self.unique_name] += 1
            self.unique_name = "_".join([self.unique_name, str(cur_graph.global_nodes[self.unique_name])])
        else:
            cur_graph.global_nodes[self.unique_name] = 0

    def type(self):
        """Returns the original name of the function or the type of the module"""
        if type(self.module) is TraceFunction:
            return self.module.func_type
        return type(self.module)

    def kind(self):
        """Returns the kind of the function or the type of the module"""
        if type(self.module) is TraceFunction:
            return self.module.kind
        return type(self.module)

    def is_class(self) -> bool:
        """Judges whether it is a class function or not"""
        if type(self.module) is TraceFunction:
            return self.module.is_class
        else:
            return False

    def full_name(self) -> str:
        """Returns the original full name of the function (including namespace)"""
        if type(self.module) in (TraceFunction, ConstantNode):
            return self.module.full_name
        else:
            return f'{type(self.module).__module__}.{type(self.module).__name__}'

    def __hash__(self) -> int:
        """Uses the unique name as the hash for the node.

        `__hash__` must return an integer, so the name is hashed rather than returned directly
        (returning the `str` made `hash(node)` raise `TypeError`).
        """
        return hash(self.unique_name)

    def prev_node_unique_name(self, idx, inplace=False) -> str:
        """A utility function to generate the name of the previous node with index

        Args:
            idx: Position of the previous node in `prev_nodes`
            inplace: Whether the original (non-numbered) name may be used

        Returns:
            str: The expression referencing that input, or `''` if `idx` is out of range
        """
        if idx < len(self.prev_nodes) and idx < len(self.prev_indices):
            # A property access on an `nn.Module` (e.g. `self.conv.weight`)
            getattr_on_module = False
            if (
                isinstance(self.prev_nodes[idx].module, torch.nn.Module)
                and type(self.module) is TraceFunction
                and self.module.is_property
                and '.' not in self.module.full_name
            ):
                getattr_on_module = True

            actual_inplace = False
            if inplace:
                actual_inplace = getattr_on_module
                if isinstance(self.prev_nodes[idx].module, ConstantNode):
                    actual_inplace = True

            if actual_inplace:
                node_name = self.prev_nodes[idx].original_name
            else:
                node_name = self.prev_nodes[idx].unique_name

            node_idx = self.prev_indices[idx]

            # Constants and FloatFunctional objects are attributes of the generated module
            ns = ''
            if type(self.prev_nodes[idx].module) in (ConstantNode, torch.nn.quantized.FloatFunctional):
                ns = 'self.'
            elif getattr_on_module:
                prev_t_ids = set(id(t) for t in self.prev_tensors)
                next_t_ids = set(id(t) for t in self.prev_nodes[idx].next_tensors)
                if len(prev_t_ids & next_t_ids) == 0:
                    ns = 'self.'

            if node_idx is None:
                return f'{ns}{node_name}'
            else:
                if isinstance(node_idx, (list, tuple)):
                    indices_str = ''.join([f'[{i}]' for i in node_idx])
                    return f'{ns}{node_name}{indices_str}'
                else:
                    return f'{ns}{node_name}[{node_idx}]'
        else:
            return ''


class ConstantNode(object):
    """A data structure for runtime-defined constants"""

    def __init__(
        self,
        data: typing.List,
        dtype: torch.dtype,
        shape: torch.Size,
        unique_name: typing.Optional[str] = None,
        original_name: typing.Optional[str] = None,
    ):
        """Creates a constant node and numbers it in `current_graph().global_functions`.

        Args:
            data: Raw data of the constant (as produced by `tensor.tolist()` in `constant_handler`)
            dtype: Data type of the constant
            shape: Shape of the constant
            unique_name: Explicit unique name. A numbered `tensor_<n>` name is generated when omitted.
            original_name: The original attribute name. When given, the node is marked `inplace`.
        """
        # Raw data (list)
        self.data = data
        # Data shape
        self.shape = tuple(shape)
        # Data type
        self.dtype = str(dtype)
        # Please refer to the the description of those properties in `TraceFunction`
        self.kind = 'tensor'
        self.func_type = 'tensor'
        self.full_name = 'torch.tensor'
        self.is_class = False
        self.is_parameter = False
        self.is_persistent = False
        self.requires_grad = False
        self.data_str = None

        # Numbering of the name of the node
        if current_graph().global_functions.get(self.kind, None) is None:
            current_graph().global_functions[self.kind] = 0
        else:
            current_graph().global_functions[self.kind] += 1

        if unique_name is None:
            self.unique_name = "_".join([self.kind, str(current_graph().global_functions[self.kind])])
        else:
            self.unique_name = unique_name

        self.inplace = original_name is not None
        if original_name is None:
            self.original_name = self.unique_name
        else:
            self.original_name = original_name

    def parse(self, convert_to_parameter: bool = False, persistent: bool = False, requires_grad: bool = False):
        """Records how the constant is emitted and, unless persistent, renders its data inline.

        Returns:
            ConstantNode: `self`, to allow chaining
        """

        def _stringify_list(content) -> str:
            """Convert a list of objects to a string"""
            if isinstance(content, (list, tuple)):
                sub_contents = []
                for item in content:
                    sub_contents.append(_stringify_list(item))
                inner_content = ', '.join(sub_contents)
                return f'[{inner_content}]'
            elif type(content) in (int, float, bool):
                return str(content)
            elif type(content) is str:
                return f'"{content}"'
            # Other element types fall through and yield `None`

        # If `convert_to_parameter` is `True`, the content of the data will not be written inline.
        self.is_parameter = convert_to_parameter
        self.is_persistent = persistent
        self.requires_grad = requires_grad
        if not persistent:
            self.data_str = f'{_stringify_list(self.data)}'

        return self


class TraceFunction(object):
    """A data structure for traced functions"""

    def __init__(
        self, full_name: str, is_class: bool = False, is_property: bool = False, prefix: typing.Optional[str] = None
    ):
        """Creates a traced function record and numbers it in `current_graph().global_functions`.

        Args:
            full_name: The name of the function including its namespace (e.g. `torch.Tensor.add`)
            is_class: Whether it is a class function/property
            is_property: Whether it is a property
            prefix: Optional prefix for the generated unique name
        """
        super().__init__()

        # The base name of the function
        self.func_type = full_name.split('.')[-1]

        # The class name of the function
        # It can be acquired by removing underlines in the base name of the function for special math
        # operators and inline functions
        self.kind = None
        if self.func_type.endswith('__') and self.func_type.startswith('__'):
            inner_name = self.func_type[2:-2]
            # In-place (`__iadd__`) and reflected (`__radd__`) operators share the kind of the base operator
            if len(inner_name) > 1 and inner_name[0] in ('i', 'r'):
                inner_op = inner_name[1:]
                if inner_op in SPECIAL_OPERATORS:
                    self.kind = inner_op
            if self.kind is None:
                self.kind = inner_name

        if self.kind is None:
            if self.func_type.endswith('_'):
                self.kind = self.func_type[:-1]
            else:
                self.kind = self.func_type

        # Numbering of the nodes
        if current_graph().global_functions.get(self.kind, None) is None:
            current_graph().global_functions[self.kind] = 0
        else:
            current_graph().global_functions[self.kind] += 1

        if prefix is None:
            prefix = ""

        # Unique name
        self.unique_name = prefix + "_".join([self.kind, str(current_graph().global_functions[self.kind])]) + "_f"

        # The input tensors of the function
        self.prev_tensors = []

        # The name of the function (including namespace)
        self.full_name = full_name

        # Whether it is a class function/property
        self.is_class = is_class

        # Whether it is a property
        self.is_property = is_property

        # Alias
        self.aliases = None

        # Arguments
        self.args = None
        self.kwargs = None
        self.args_string = None
        self.args_parsed = None
        self.args_parsed_origin = None
        self.tensor_names = None
        self.original_tensor_names = None
        self.args_template = None
        self.args_template_no_self = None
        self.args_offset = None

    def __repr__(self) -> str:
        """Returns the string representation of the object (as a lambda expression)"""
        arg_len = len(self.tensor_names)
        if arg_len > 0:
            prefix = 'lambda args: '
        else:
            prefix = 'lambda: '
        extra_expr = self.extra_expr('args')
        expr = f'{prefix}{extra_expr}'
        return expr

    def extra_expr(self, prefix=None, first=None, original=False):
        """Returns the extra string representation of the object

        Args:
            prefix: When given, the tensor arguments are rendered as `prefix[i]`
            first: When given, replaces the name of the first tensor argument
            original: Whether to use the original tensor names (only used when `prefix` is None)

        Returns:
            str: The call expression
        """
        arg_len = len(self.tensor_names)
        if arg_len == 0:
            expr = f'{self.full_name}({self.args_template})'
        else:
            if prefix is None:
                if original:
                    tensor_names = [
                        (o_name if o_name.startswith('self.') else f'self.{o_name}')
                        if u_name.startswith('self.')
                        else u_name
                        for u_name, o_name in zip(self.tensor_names, self.original_tensor_names)
                    ]
                else:
                    tensor_names = self.tensor_names
            else:
                tensor_names = [f'{prefix}[{i}]' for i in range(arg_len)]

            if first is not None:
                tensor_names = [first] + tensor_names[1:]

            if self.is_class:
                if self.is_property:
                    expr = f'{tensor_names[0]}.{self.func_type}'
                elif self.func_type == '__getitem__':
                    args = self.args_string_no_self
                    if not args.startswith('[') and not args.endswith(']'):
                        args = f'[{args}]'
                    expr = f'{tensor_names[0]}{args}'
                else:
                    full_template = f'{{}}.{self.func_type}({self.args_template_no_self})'
                    expr = full_template.format(*tensor_names)
            else:
                full_template = f'{self.full_name}({self.args_template})'
                expr = full_template.format(*tensor_names)

        return expr

    def __call__(self, *args, **kwargs):
        """Calls the function with a list of tensor inputs

        Only the first positional argument (a tuple or list of tensors) is used.
        """
        if len(kwargs) > 0:
            log.warning('Keyword arguments are ignored when calling TraceFunction')
        if len(args) == 0:
            arg_len = 0
        else:
            if len(args) > 1:
                log.warning('Multiple arguments are passed in, but all but the first one will be ignored')

            if not isinstance(args[0], (tuple, list)):
                log.error('Only tuple or list is accepted here')
                assert False

            arg_len = len(args[0])

        expected_len = len(self.tensor_names)
        if arg_len != expected_len:
            log.error(f'Wrong number of input tensors, expected: {expected_len}, but got {arg_len}')
            assert False

        # `args` is referenced by name inside the evaluated expression
        expr = self.extra_expr('args[0]')
        return eval(expr)

    def parse_args(self, *args, **kwargs):
        """Sets the string representation of the arguments

        Tensors (and other tracked objects) are replaced with `{}` placeholders in the template and
        their names are collected in `tensor_names` / `original_tensor_names`.

        Returns:
            TraceFunction: `self`, to allow chaining
        """

        def _tensor_name(a, convert_to_parameter=False, original=False):
            """Get the tensor name from the computation graph"""
            ns = ''
            if constant_handler(a, self.unique_name, self.full_name):
                # A new constant node was just created, which lives on the generated module
                ns = 'self.'
                pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
                if original:
                    node = current_graph().nodes_map[pre_node_name]
                    pre_node_name = node.original_name
            else:
                pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
                node = current_graph().nodes_map[pre_node_name]
                if original:
                    pre_node_name = node.original_name
                if type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional):
                    ns = 'self.'
            if id(a) in current_graph().tensor_pre_index_dict:
                pre_node_index = current_graph().tensor_pre_index_dict[id(a)]
                log.debug(f'pre_index gen func {self.kind}: {pre_node_index}')
                if isinstance(pre_node_index, (list, tuple)):
                    indices_str = ''.join([f'[{i}]' for i in pre_node_index])
                    return f"{ns}{pre_node_name}{indices_str}"
                else:
                    return f"{ns}{pre_node_name}[{pre_node_index}]"
            else:
                return f"{ns}{pre_node_name}"

        def _escape_arg(arg: str):
            """Escapes the special characters in the argument string"""
            for c in ('{', '}'):
                if c in arg:
                    arg = arg.replace(c, f'{c}{c}')
            return arg

        def _parse_args(arg):
            """Converts the argument to a list of strings"""
            new_arg = []

            for a in arg:
                if isinstance(a, (list, tuple, torch.Size)):
                    new_arg.append(_parse_args(a))
                elif type(a) in (torch.Tensor, torch.nn.Parameter) or (
                    type(a) in (torch.dtype, torch.device, torch.Size) and id(a) in current_graph().tensor_pre_node_dict
                ):
                    self.prev_tensors.append(a)
                    self.tensor_names.append(_tensor_name(a))
                    self.original_tensor_names.append(_tensor_name(a, original=True))
                    new_arg.append('{}')
                elif type(a) in (str, torch.device):
                    new_arg.append(_escape_arg(f"\'{a}\'"))
                elif type(a) in (int, bool, torch.dtype):
                    new_arg.append(str(a))
                elif type(a) is float:
                    str_arg = str(a)
                    if str_arg in ('nan', 'inf', '-inf'):
                        new_arg.append(f"float('{str_arg}')")
                    else:
                        new_arg.append(str_arg)
                elif a is None:
                    new_arg.append('None')
                elif a is Ellipsis:
                    new_arg.append('...')
                elif type(a) is slice:
                    t = (a.start, a.stop, a.step)
                    parts = []
                    for x in t:
                        if x is None:
                            parts.append('')
                        else:
                            parts.extend(_parse_args([x]))
                    r = ':'.join(parts)
                    if r.endswith(':'):
                        r = r[:-1]
                    new_arg.append(r)
                elif isinstance(a, torch.nn.quantized.FloatFunctional):
                    # FloatFunctional objects become attributes of the generated module
                    float_functional_cls = type(a)
                    module_constructor_lines[id(a)] = f'{qualified_name(float_functional_cls, short=True)}()'

                    new_node = TraceNode(a)
                    current_graph().nodes_map[new_node.unique_name] = new_node
                    current_graph().other_init_nodes.append(new_node)
                    current_graph().tensor_pre_node_dict[id(a)] = new_node.unique_name
                    self.tensor_names.append(_tensor_name(a))
                    self.original_tensor_names.append(_tensor_name(a, original=True))
                    self.prev_tensors.append(a)
                    new_arg.append('{}')
                elif isinstance(a, nn.Module):
                    unique_name = current_graph().module_unique_name_dict[id(a)]
                    current_graph().tensor_pre_node_dict[id(a)] = unique_name
                    self.tensor_names.append(f'self.{unique_name}')
                    self.original_tensor_names.append(_tensor_name(a, original=True))
                    self.prev_tensors.append(a)
                    new_arg.append('{}')
                else:
                    log.error(f"unsupported type {type(a)} while generating arg for func {self.full_name}")
                    assert False

            return new_arg

        self.tensor_names = []
        self.original_tensor_names = []
        self.prev_tensors.clear()
        arg_str = _parse_args(args)

        kw_items = kwargs.items()
        if kw_items:
            kw_keys, kw_vals = zip(*kw_items)
            kw_val_strs = _parse_args(kw_vals)

            for (k, v) in zip(kw_keys, kw_val_strs):
                if type(v) is list:
                    v_str = self._flatten_list(v)
                    arg_str.append(f"{k}={v_str}")
                else:
                    arg_str.append(f"{k}={v}")

        self.args_parsed = copy.deepcopy(arg_str)
        self.kwargs = copy.deepcopy(kwargs)
        self.args_parsed_origin = copy.deepcopy(self.args_parsed)
        self.args_to_string(self.args_parsed)

        return self

    def _flatten_list(self, content):
        """Flatten a list of nested list or string into a string"""
        if isinstance(content, list):
            sub_contents = []
            for item in content:
                sub_contents.append(self._flatten_list(item))
            inner_content = ', '.join(sub_contents)
            return f'[{inner_content}]'
        else:
            return content

    def args_to_string(self, arg_str):
        """Builds the argument templates from the parsed argument strings (flattened in place).

        Returns:
            TraceFunction: `self`, to allow chaining
        """
        for i, item in enumerate(arg_str):
            if type(item) is list:
                arg_str[i] = self._flatten_list(item)

        try:
            self.args_template = ", ".join(arg_str)
            self.args_template_no_self = ", ".join(arg_str[1:])
            # Guard against functions called without any arguments (an empty `arg_str`),
            # which previously failed with an `IndexError`.
            self.args_offset = 1 if len(arg_str) > 0 and arg_str[0] == '{}' else 0
            self.update_args_string()
        except Exception:
            log.error(f"Error generating argument string for function {self.full_name}")
            assert False

        return self

    def update_tensor_name(self, index, new_name):
        """Updates the tensor name at the given index"""
        self.tensor_names[index] = new_name

    def replace_tensor_name(self, old_name, new_name):
        """Replaces the specific tensor name with the given one"""
        for i, name in enumerate(self.tensor_names):
            if name == old_name:
                self.tensor_names[i] = new_name

    def update_args_string(self):
        """Updates the string representation according to the templates and the tensor names"""
        if self.args_template:
            self.args_string = self.args_template.format(*self.tensor_names)
        if self.args_template_no_self:
            self.args_string_no_self = self.args_template_no_self.format(*self.tensor_names[self.args_offset :])

    def get_tensor_name(self, index, original):
        """Retrieves the tensor name at the given index"""
        if original:
            return self.original_tensor_names[index]
        else:
            return self.tensor_names[index]

    def add_alias(self, name, head=False):
        """Adds an aliases of the tensor"""
        if self.aliases is None:
            self.aliases = []
        if head:
            self.aliases.insert(0, name)
        else:
            self.aliases.append(name)

    def get_aliases(self):
        """Retrieves the aliases of the tensor"""
        return self.aliases


@contextlib.contextmanager
def no_catch():
    """Context manager for tracing nodes.
    Use it to avoid tracing the nodes recursively.

    Yields:
        bool: `True` if this call acquired the tracing lock, `False` if it was already held.
        A lock acquired here is always released on exit, even if the body raises; otherwise an
        exception in a wrapped call would leave tracing permanently disabled.
    """
    if lock():
        yield False
    else:
        lock(True)
        try:
            yield True
        finally:
            lock(False)


@contextlib.contextmanager
def no_catch_handle_func():
    """Context manager for tracing nodes.
    Use it to avoid hacking into handle_torch_function recursively.

    Yields:
        bool: `True` if this call acquired `handle_func_lock`, `False` if it was already held.
        A lock acquired here is always released on exit, even if the body raises.
    """
    if handle_func_lock():
        yield False
    else:
        handle_func_lock(True)
        try:
            yield True
        finally:
            handle_func_lock(False)


def args_as_string(args, kwargs):
    """String representation of the args and the keyword args"""
    cleaned_args = [f'"{arg}"' if type(arg) is str else str(arg) for arg in args]
    args_content = ', '.join(cleaned_args)
    kwargs_content = ', '.join((f'{k}="{v}"' if type(v) is str else f'{k}={v}' for k, v in kwargs.items()))
    args_connector = '' if args_content == '' or kwargs_content == '' else ', '
    full_args_content = f'{args_content}{args_connector}{kwargs_content}'
    return full_args_content


def new_setattr_gen(orig_setattr, key: str):
    """Wrapper function for the __setattr__ functions of the modules in PyTorch

    When a traced module changes one of its `__constants__` after construction, the recorded
    constructor line no longer describes it, so the line is dropped.
    """
    log.debug(f'registered module setattr wrapper: {key}')

    def new_setattr(obj, name, value):
        log.debug(f'{key} in setattr function wrapper')
        related = True
        log.debug(f'{key} before with block, lock: {lock}')
        with no_catch() as res:
            if res:
                if id(obj) not in module_constructor_traced:
                    related = False
                class_type = type(obj)
                if related and not hasattr(class_type, '__constants__'):
                    related = False
                if related and name not in class_type.__constants__:
                    related = False
                if related:
                    class_name = '.'.join(key.split('.')[:-1])
                    log_func = log.warning
                    if mod_param_update_warning_ignore():
                        log_func = log.debug
                    log_func(
                        f'The constant property `{name}` of {class_name} is changed. We need to drop the original'
                        ' constructor line.'
                    )
                    module_constructor_traced.remove(id(obj))
                    # A module may be traced without a constructor line (e.g. unsupported
                    # constructor arguments), so don't assume the entry exists.
                    module_constructor_lines.pop(id(obj), None)
            return orig_setattr(obj, name, value)

    return new_setattr


def new_module_getattr_gen(orig_getattr, key: str, res: typing.Dict[str, bool]):
    """Wrapper function for the attribute getters of modules, used while they are `repr`-ed

    Args:
        orig_getattr: The original getter
        key: The full name of the wrapped getter (used for logging)
        res: Cache mapping `<filename>_<lineno>` of the calling line to whether it should be patched
    """
    log.debug(f'registered module getattr wrapper: {key}')

    def new_getattr(obj, name):
        log.debug(f'{key} in module getattr function wrapper')
        result = orig_getattr(obj, name)
        result_is_str = isinstance(result, str)
        is_dict = name == '__dict__'
        related = result_is_str or is_dict

        # If the following conditions are satisfied
        #   a. the type of the result is `str` or the key is `__dict__`
        #   b. the 2nd last frame is on the `extra_repr` function or `__repr__` function
        #   c. the line is not starting with `if `
        # then we should patch the variable.
        if related:
            last_frame = inspect.currentframe().f_back
            func_name = last_frame.f_code.co_name
            related = False
            if func_name in ('extra_repr', '__repr__'):
                if is_dict:
                    related = True
                else:
                    fn = last_frame.f_code.co_filename
                    ln = last_frame.f_lineno
                    res_key = f'{fn}_{ln}'
                    if res_key in res:
                        related = res[res_key]
                    else:
                        lines = inspect.getframeinfo(last_frame)[3]
                        if lines:
                            related = not re.match(r' *if .*', lines[0]) and not re.match(
                                r'.*repr\(self\..*\)', lines[0]
                            )
                        else:
                            # The source of the calling line is unavailable (`code_context` is None),
                            # so it cannot be inspected; leave the value unpatched.
                            related = False
                        res[res_key] = related

        if related:
            if result_is_str:
                # Quote the string so that it is emitted as a string literal
                return f'"{result}"'
            elif is_dict:
                orig = result
                result = {}
                for k, v in orig.items():
                    if type(v) is str and (not k.startswith('__') and not k.endswith('__')):
                        result[k] = f'"{orig[k]}"'
                    else:
                        result[k] = v
        return result

    return new_getattr


def constant_handler(
    tensor: torch.Tensor, node_name: typing.Optional[str] = None, type_name: typing.Optional[str] = None
) -> bool:
    """Appends the constant tensor to the computation graph

    Args:
        tensor (torch.Tensor): The constant tensor
        node_name (typing.Optional[str], optional): The name of the node that depends on that tensor. Defaults to None.
        type_name (typing.Optional[str], optional): The type of the node that depends on that tensor. Defaults to None.

    Returns:
        bool: Whether it is a new constant tensor
    """
    if id(tensor) not in current_graph().tensor_pre_node_dict:
        if id(tensor) in current_graph().parameter_module_dict:
            # The tensor is a parameter of a related module, so it is accessed as an attribute of it
            mod_id = current_graph().parameter_module_dict[id(tensor)]
            unique_name = current_graph().module_unique_name_dict[mod_id]
            if unique_name in current_graph().related_modules:
                mod = ctypes.cast(mod_id, ctypes.py_object).value
                node = TraceNode(mod)
                add_forward_node(node, [], [])
                current_graph().tensor_pre_node_dict[mod_id] = node.unique_name
                key = current_graph().parameter_original_name_dict[id(tensor)].split('.')[-1]
                trace_func = TraceFunction(key, True, True).parse_args(mod)
                trace_node = TraceNode(trace_func)
                add_forward_node(trace_node, [mod], tensor)
                return False

        if not tensor.is_leaf:
            if node_name is None:
                node_name = 'unknown node'
            if type_name is None:
                type_name = 'unknown'
            log.error(f'Connection is lost when generating code for {node_name} of type {type_name}')
        else:
            # constant tensor generation
            log.warning('Constant generation is experimental and may yield error')
            convert_to_parameter = False
            persistent = False
            requires_grad = tensor.requires_grad
            if isinstance(tensor, torch.nn.Parameter):
                convert_to_parameter = True
            # Large constants are not written inline
            if tensor.numel() > 50:
                persistent = True
            raw_data = tensor.tolist()
            unique_name = current_graph().parameter_unique_name_dict.get(id(tensor), None)
            original_name = current_graph().parameter_original_name_dict.get(id(tensor), None)
            with no_catch():
                constant_node = ConstantNode(raw_data, tensor.dtype, tensor.shape, unique_name, original_name).parse(
                    convert_to_parameter, persistent, requires_grad
                )
            trace_node = TraceNode(constant_node)
            add_constant_node(trace_node, tensor)
            return True
    else:
        return False


def new_getattr_gen(orig_getattr, key: str, is_class: bool):
    """Wrapper function for the __getattribute__ functions of the modules in PyTorch

    Accesses to `device`, `shape`, `data` and `dtype` are recorded as property nodes while a graph
    is being traced.
    """
    log.debug(f'registered module getattr wrapper: {key}')

    def new_getattr(obj, name):
        log.debug(f'{key} in getattr function wrapper')
        log.debug(f'{key} before with block, lock: {lock}')
        with no_catch() as res:
            result = orig_getattr(obj, name)

            # Only the listed properties are tracked, and only while there is a graph to record into.
            # (Previously the `current_graph() is None` check was overwritten by the name check.)
            related = name in ('device', 'shape', 'data', 'dtype') and current_graph() is not None

            if res:
                if related:
                    # Also the property should be constant if the result object is unchanged.
                    # Only create a new node when there isn't one.
                    if (
                        id(result) in current_graph().tensor_pre_node_dict
                        and id(result) in original_values_for_tracked_objects
                        and original_values_for_tracked_objects[id(result)] == result
                    ):
                        node_name = current_graph().tensor_pre_node_dict[id(result)]
                        trace_node = current_graph().nodes_map[node_name]
                        if trace_node.module.is_property and trace_node.module.func_type == 'shape':
                            result = tuple(trace_node.next_tensors)
                    else:
                        # Handling dynamic shape

                        # If the torch.Size object is generated by a tensor,
                        # then we connect it to the graph.
                        # Otherwise, don't track it.
                        old_result = None
                        if type(result) is torch.Size and isinstance(obj, torch.Tensor):
                            # Create a list of new tensors for the sake of tracking
                            # The reason to use that instead of a tensor is stated below.
                            # e.g. Users may use the following clause to deal with sizes
                            #      x, y = tensor.size()
                            # Currently, there is no way to trace it.
                            # However, by doing this, if user calls `numel` on the `torch.Size`
                            # object, it will now throw an exception.
                            # TODO: Fix the case if user calls `numel` on `torch.Size`
                            constant_handler(obj, type_name=key)
                            original_values_for_tracked_objects[id(result)] = copy.deepcopy(result)
                            new_result = []
                            for elem in result:
                                new_result.append(torch.tensor(elem))
                                current_graph().tensor_pre_node_dict[
                                    id(new_result[-1])
                                ] = current_graph().tensor_pre_node_dict[id(obj)]
                            old_result = result
                            result = tuple(new_result)
                        log.debug(f'{key} is called with {name}')
                        new_key = key.replace('__getattribute__', name)
                        trace_func = TraceFunction(new_key, True, True).parse_args(obj)
                        trace_node = TraceNode(trace_func)
                        if old_result is not None:
                            current_graph().tensor_pre_node_dict[id(old_result)] = trace_node.unique_name
                        add_forward_node(trace_node, trace_func.prev_tensors, result)
        log.debug(f'{key} after with block, lock: {lock}')
        return result

    return new_getattr


def new_init_tracking_gen(orig_init, key: str):
    """Wrapper for tracked module constructors that converts the first argument into tensors

    The generated function returns `args[0].unbind(0)` when `args[0]` is a tensor, and a tuple of
    `torch.as_tensor(x)` for each item of `args[0]` otherwise. `orig_init` is not called.
    """
    log.debug(f'registered module init tracking wrapper: {key}')

    def new_init_tracking(obj, args, kwargs):
        log.debug(f'{key} in init tracking function wrapper')
        with no_catch():
            is_tensor = torch.is_tensor(args[0])

        if is_tensor:
            return args[0].unbind(0)
        else:
            with no_catch():
                return tuple([torch.as_tensor(x) for x in args[0]])

    return new_init_tracking


def new_init_gen(orig_init, key: str):
    """Wrapper function for the init functions of the modules in PyTorch

    Records a constructor line (`<class>(<args>)`) for modules whose exact class is the wrapped one
    and whose constructor arguments pass `check_types`.
    """
    log.debug(f'registered module init wrapper: {key}')

    def new_init(obj, *args, **kwargs):
        log.debug(f'{key} in init function wrapper')
        module_constructor_traced.add(id(obj))
        init_fullname = key
        class_fullname = '.'.join(init_fullname.split('.')[:-1])
        log.debug(f'{key} before with block, lock: {lock}')
        with no_catch() as res:
            # `id()` values may be reused after an object dies, so a dead weakref means the stored
            # constructor line belongs to a stale object.
            if id(obj) not in module_constructor_lines or (
                id(obj) in module_constructor_weakrefs and module_constructor_weakrefs[id(obj)]() is None
            ):
                if not res:
                    log.warning(
                        f'Failed to acquire the tracing lock while tracing {init_fullname}, which is unexpected.'
                    )
                log.debug(f'{key} in with block, lock: {lock}')
                actual_class_name = qualified_name(type(obj))
                if actual_class_name == class_fullname:
                    rckwa = check_types(kwargs.values())
                    rca = check_types(args)
                    err_type = rca or rckwa
                    if err_type:
                        log.warning(
                            f'Constructor of {class_fullname} has arguments of type {err_type} which is unsupported'
                        )
                        log.warning(f' Args: {args}')
                        log.warning(f' Keyword args: {kwargs}')
                    else:
                        log.info(f'Constructor of {class_fullname} registered')
                        full_args_content = args_as_string(args, kwargs)
                        orig_constructor_line = f'{class_fullname}({full_args_content})'
                        module_constructor_lines[id(obj)] = orig_constructor_line
                        module_constructor_weakrefs[id(obj)] = weakref.ref(obj)
                else:
                    module_constructor_traced.remove(id(obj))
                    if not actual_class_name.startswith('torch.'):
                        log.warning(f'Constructor of class {actual_class_name} is not captured')
            orig_init(obj, *args, **kwargs)
        log.debug(f'{key} after with block, lock: {lock}')

    return new_init


def new_func_gen(orig_func, key: str, is_class: bool):
    """Wrapper function for functions in PyTorch

    The original function is always called; its call is recorded as a graph node when the result
    is a tensor-like value, a `torch.dtype`/`torch.device`, or a `torch.Size` of a tensor.
    """
    log.debug(f'registered function wrapper: {key}')

    def new_func(*args, **kwargs):
        log.debug(f'{key} in function wrapper')
        related = False
        log.debug(f'{key} before with block, lock: {lock}')
        with no_catch() as res:
            result = orig_func(*args, **kwargs)
            if res and current_graph() is not None:
                log.debug(f'{key} in with block, lock: {lock}')
                if key == 'torch.Tensor.size' and len(args) > 1:
                    # Tracking torch.Tensor.size with optional int argument
                    result = torch.tensor(result)
                if type(result) is torch.Size:
                    # Handling dynamic shape

                    # If the torch.Size object is generated by a tensor,
                    # then we connect it to the graph.
                    # Otherwise, don't track it.
                    if len(args) > 0 and type(args[0]) is torch.Tensor:
                        # Create a list of new tensors for the sake of tracking
                        # The reason to use that instead of a tensor is stated below.
                        # e.g. Users may use the following clause to deal with sizes
                        #      x, y = tensor.size()
                        # Currently, there is no way to trace it.
                        # However, by doing this, if user calls `numel` on the `torch.Size`
                        # object, it will now throw an exception.
                        # TODO: Fix the case if user calls `numel` on `torch.Size`
                        new_result = []
                        for elem in result:
                            new_result.append(torch.tensor(elem))
                            current_graph().tensor_pre_node_dict[
                                id(new_result[-1])
                            ] = current_graph().tensor_pre_node_dict[id(args[0])]
                        result = tuple(new_result)
                        related = True
                elif type(result) in (torch.dtype, torch.device):
                    related = True
                else:
                    related = check_tensor_type(result)
        log.debug(f'{key} after with block, lock: {lock}')

        if related:
            log.debug(f'tracing {key} in function wrapper')
            with no_catch() as res:
                if res:
                    trace_func = TraceFunction(key, is_class).parse_args(*args, **kwargs)
                    trace_node = TraceNode(trace_func)
                    modified_result = noop_handler(trace_node, trace_func.prev_tensors, result)
                    if modified_result is not None:
                        result = modified_result
                    add_forward_node(trace_node, trace_func.prev_tensors, result)
            log.debug(f'tracing {key} function wrapper complete')

        return result

    return new_func


def new_has_torch_func_gen(orig_func, key: str, is_class: bool):
    """Wrapper for the `has_torch_function*` checks in PyTorch

    The generated function returns `True` when it acquires `handle_func_lock` while the tracing
    lock is not held; otherwise it defers to the original check.
    """
    log.debug(f'registered has torch func wrapper: {key}')

    def new_func(*args, **kwargs):
        with no_catch_handle_func() as res:
            return (res and not lock()) or orig_func(*args, **kwargs)

    return new_func


def new_handle_func_gen(orig_func, key: str, is_class: bool):
    """Wrapper for `handle_torch_function` in PyTorch

    While the tracing lock is held, the original handler is used; otherwise `func` is called
    directly under `handle_func_lock`.
    """
    log.debug(f'registered handle torch func wrapper: {key}')

    def new_func(func, tracked_args, *args, **kwargs):
        if lock():
            return orig_func(func, tracked_args, *args, **kwargs)
        else:
            with no_catch_handle_func():
                return func(*args, **kwargs)

    return new_func


def new_creation_func_gen(orig_func, key: str, is_class: bool):
    """Wrapper function for tensor creation functions in PyTorch

    Every call made outside of an already-traced call is recorded as a graph node, provided a
    graph is currently being traced.
    """
    log.debug(f'registered creation function wrapper: {key}')

    def new_func(*args, **kwargs):
        log.debug(f'{key} in creation function wrapper')
        with no_catch() as res:
            result = orig_func(*args, **kwargs)
        log.debug(f'tracing {key} in creation function wrapper')
        with no_catch() as res:
            # Same guard as `new_func_gen`: `TraceFunction` needs a current graph
            if res and current_graph() is not None:
                trace_func = TraceFunction(key, is_class).parse_args(*args, **kwargs)
                trace_node = TraceNode(trace_func)
                add_forward_node(trace_node, trace_func.prev_tensors, result)
        return result

    return new_func


def fetch_modules(config: typing.Optional[str] = None):
    """Fetches the module classes from the config.

    Args:
        config: Path to a YAML file mapping a namespace to a list of class names.
            Defaults to `configs/torch_module_override.yml` next to this script.

    Returns:
        list: The classes that could be found. Namespaces that fail to import are skipped.
    """
    if config is None:
        config = os.path.join(current_dir, 'configs/torch_module_override.yml')
    modules = []

    with open(config, 'r') as f:
        module_dict = yaml.load(f, yaml.SafeLoader)

    for ns, module_names in module_dict.items():
        try:
            scope = importlib.import_module(ns)
        except ImportError:
            # Skip the namespace entirely. Falling through would look the names up on an undefined
            # (first iteration) or stale `scope` from a previous namespace.
            continue
        for module_name in module_names:
            if hasattr(scope, module_name):
                module = getattr(scope, module_name)
                modules.append(module)
                importable_module_names[module] = f'{ns}.{module_name}'
                if hasattr(module, '__init__'):
                    constructor = module.__init__
                    module_constructor_signatures[module] = inspect.signature(constructor).parameters.values()
    return modules


def fetch_funcs(config: typing.Optional[str] = None):
    """Fetches the functions from the config.

    Args:
        config: Path to a YAML file mapping a namespace (a module, or `<module>.<Class>`) to a list
            of function names. Defaults to the `configs/torch_func_override_<major>_<minor>.yml`
            file matching the installed PyTorch version.

    Returns:
        dict: A mapping from the resolved scope (module or class) to the names found on it
    """
    # `import importlib` alone does not guarantee that the `importlib.util` submodule is loaded
    import importlib.util

    if config is None:
        version_parts = torch.__version__.split('.')
        if int(version_parts[0]) == 1:
            if int(version_parts[1]) < 6:
                version_parts[1] = '6'
            if int(version_parts[1]) > 12:
                version_parts[1] = '12'
        elif int(version_parts[0]) == 2:
            if int(version_parts[1]) > 0:
                version_parts[1] = '0'
        else:
            log.warning(f'Your PyTorch version is unsupported: {torch.__version__}')
            version_parts = ['1', '6']

        version_str = '_'.join(version_parts[:2])
        config = os.path.join(current_dir, f'configs/torch_func_override_{version_str}.yml')

    with open(config, 'r') as f:
        module_dict = yaml.load(f, yaml.SafeLoader)

    new_dict = {}
    for ns, module_names in module_dict.items():
        log.debug(f'Attempting to load {ns}')
        try:
            spec = importlib.util.find_spec(ns)
        except ImportError:
            continue

        if spec is None:
            # `ns` is not a module, so treat the last component as a class in the parent module
            ns_parts = ns.split('.')
            ns = '.'.join(ns_parts[:-1])
            typename = ns_parts[-1]
            try:
                spec = importlib.util.find_spec(ns)
            except (ImportError, ValueError):
                # `ValueError` is raised for an empty module name (a single-component `ns`)
                spec = None
            if spec is None:
                log.warning(f"Error importing {ns}, which may not be a module")
                continue

            scope = importlib.import_module(ns)
            if hasattr(scope, typename):
                scope = getattr(scope, typename)
            else:
                log.warning(f"Error importing {ns}.{typename}")
                continue
        else:
            scope = importlib.import_module(ns)

        modules = []
        for module_name in module_names:
            if hasattr(scope, module_name):
                modules.append(module_name)
                # NOTE(review): when `scope` is a class (e.g. `torch.Tensor`), `ns` has been trimmed to
                # the parent module, so the recorded name omits the class name — confirm this is intended.
                importable_module_names[getattr(scope, module_name)] = f'{ns}.{module_name}'

        new_dict[scope] = modules

    return new_dict


def prepare_torch_overrides_funcs(funcs):
    tracked_funcs = []
    wrappers = []

    if hasattr(torch, 'overrides') and inspect.ismodule(torch.overrides):
        all_has_torch_func_names = ['has_torch_function', 'has_torch_function_unary', 'has_torch_function_variadic']
        all_handle_func_names = ['handle_torch_function']

        has_torch_func_names = []
        for n in all_has_torch_func_names:
            if hasattr(torch.overrides, n):
                has_torch_func_names.append(n)

        handle_func_names = []
        for n in all_handle_func_names:
            if hasattr(torch.overrides, n):
                handle_func_names.append(n)

        has_torch_funcs = {torch.overrides: has_torch_func_names}

        handle_funcs = {torch.overrides: handle_func_names}

        for ns in funcs.keys():
            if ns == torch.Tensor:
                if hasattr(torch, '_tensor') and inspect.ismodule(torch._tensor):
                    ns = torch._tensor
                else:
                    ns = sys.modules['torch.tensor']

            ns_has_torch_func_names = []
            for k in has_torch_func_names:
                if hasattr(ns, k):
                    ns_has_torch_func_names.append(k)

            ns_handle_func_names = []
            for k in handle_func_names:
                if hasattr(ns, k):
                    ns_handle_func_names.append(k)

            if len(ns_has_torch_func_names) > 0:
                has_torch_funcs.update({ns: ns_has_torch_func_names})

            if len(ns_handle_func_names) > 0:
                handle_funcs.update({ns: ns_handle_func_names})

        tracked_funcs.extend((has_torch_funcs, handle_funcs))
wrappers.extend((new_has_torch_func_gen, new_handle_func_gen)) return tracked_funcs, wrappers def qualified_name(module, item: typing.Optional[str] = None, short: bool = False): if module in importable_module_names: obj_key = importable_module_names[module] elif hasattr(module, '__module__'): obj_key = f'{module.__module__}.{module.__name__}' if short: mod = module.__module__ name = module.__name__ pos = [i for i, x in enumerate(mod) if x == '.'] for i in pos: ns = mod[:i] if ns in sys.modules: cur_mod = sys.modules[ns] if hasattr(cur_mod, name) and getattr(cur_mod, name) == module: obj_key = f'{ns}.{name}' break elif isinstance(module, str): obj_key = module else: obj_key = module.__name__ if item is not None: return f'{obj_key}.{item}' else: return obj_key @contextlib.contextmanager def patch(object, name, gen, *args, **kwargs): """Temporarily monkeypatches an object.""" pre_patched_value = getattr(object, name) setattr(object, name, gen(pre_patched_value, *args, **kwargs)) yield object setattr(object, name, pre_patched_value) @contextlib.contextmanager def patch_modules(objects, names, gens): """Temporarily monkeypatches the modules in PyTorch.""" if not isinstance(names, (tuple, list)) and not isinstance(gens, (tuple, list)): names = (names,) gens = (gens,) pre_patched_values = {} for obj in objects: for name, gen in zip(names, gens): key = qualified_name(obj, name, short=True) pre_patched_values[key] = getattr(obj, name) generated_wrapper_modules.setdefault(key, gen(pre_patched_values[key], key)) if key == 'torch.Size.__new__': pre_patched_values[key] = patch_new(obj, generated_wrapper_modules[key]) else: setattr(obj, name, generated_wrapper_modules[key]) yield objects for obj in objects: for name in names: key = qualified_name(obj, name, short=True) pre_patched_value = pre_patched_values[key] if key == 'torch.Size.__new__': revert_new(obj, pre_patched_value) else: setattr(obj, name, pre_patched_value) @contextlib.contextmanager def patch_funcs(object_dicts, 
gens): """Temporarily monkeypatches the functions in PyTorch.""" if not isinstance(object_dicts, (tuple, list)) and not isinstance(gens, (tuple, list)): object_dicts = (object_dicts,) gens = (gens,) pre_patched_value_dict = {} for object_dict, gen in zip(object_dicts, gens): for obj, names in object_dict.items(): for name in names: key = qualified_name(obj, name) if key in pre_patched_value_dict: log.warning(f'{key} declared more than once in torch_func_override.yml, skipping') else: if key == 'torch.Tensor.__getitem__': pre_patched_value_dict[key] = getattr(torch._C._TensorBase, '__getitem__') generated_wrapper_funcs.setdefault( key, gen(pre_patched_value_dict[key], key, hasattr(obj, '__module__')) ) new_func = generated_wrapper_funcs[key] key = qualified_name(obj, name) pre_patched_value_dict[key] = patch_getitem(obj, new_func) else: pre_patched_value_dict[key] = getattr(obj, name) generated_wrapper_funcs.setdefault( key, gen(pre_patched_value_dict[key], key, hasattr(obj, '__module__')) ) setattr(obj, name, generated_wrapper_funcs[key]) yield object_dict for object_dict, gen in zip(object_dicts, gens): for obj, names in object_dict.items(): for name in names: key = qualified_name(obj, name) if key == 'torch.Tensor.__getitem__': key = qualified_name(obj, name) revert_getitem(obj, pre_patched_value_dict[key]) else: setattr(obj, name, pre_patched_value_dict[key]) def get_constructor_args(actual_class_type): """Gets the args of the original constructor for a known module class""" if actual_class_type in module_constructor_signatures: return module_constructor_signatures[actual_class_type] else: return inspect.signature(actual_class_type.__init__).parameters.values() def gen_module_constrctor_line(module, mod_cache=None): """Generates the constructor line for a loaded module""" ignored_args = { 'torch.nn.ZeroPad2d': ['value'], 'torch.nn.UpsamplingNearest2d': ['mode'], 'torch.nn.UpsamplingBilinear2d': ['mode'], } legacy_modules = { 'torch.nn.FractionalMaxPool2d', 
'torch.nn.FractionalMaxPool3d', 'torch.nn.TransformerEncoderLayer', 'torch.nn.TransformerDecoderLayer', 'torch.nn.Upsample', } def _skip_ignored_args(name, *args, **kwargs): iargs = ignored_args[name] for arg in iargs: if isinstance(arg, int): if arg >= 0 and arg < len(args): del args[arg] elif isinstance(arg, str): if arg in kwargs: del kwargs[arg] else: raise AttributeError(f'Unknown type {type(arg)} in ignored args') return args_as_string(args, kwargs) name = qualified_name(type(module), short=True) if mod_cache is None: mod_cache = {} module_cls = type(module) if name in legacy_modules: if hasattr(module_cls, '__constants__') and hasattr(module_cls, '__init__'): known_constants = set(module_cls.__constants__) arg_info = get_constructor_args(module_cls) args = [] for p in arg_info: prop_name = p.name if prop_name == 'self': continue if prop_name not in known_constants: if not hasattr(module_cls, prop_name) or not isinstance( getattr(module_cls, prop_name), (type(None), str, int, float, bool) ): log.warning( f'Argument "{prop_name}" of the constructor of {name} is not a known constant, skipping' ) continue if p.default is p.empty: prop_value = getattr(module, prop_name) args.append(f'{prop_value}') else: # If loading from legacy model, then the property can be missing. # Thus, we should skip it so as to use the default value. 
if not hasattr(module, prop_name): continue prop_value = getattr(module, prop_name) # Appending keyword args default_value = p.default # Skip the arg if it has the same value with the default one if default_value == prop_value: continue if type(prop_value) is str: prop_value_str = f'"{prop_value}"' else: prop_value_str = prop_value args.append(f'{prop_name}={prop_value_str}') mid = ', '.join(args) result = f'{name}({mid})' else: with patch(module_cls, '__getattribute__', new_module_getattr_gen, name, mod_cache): if getattr(module_cls, '__repr__') == getattr(torch.nn.Module, '__repr__'): result = f'{name}({module.extra_repr()})' else: ns = '.'.join(name.split('.')[:-1]) result = f'{ns}.{repr(module)}' if name in ignored_args: start_pos = result.find('(') end_pos = result.rfind(')') head = result[: start_pos + 1] tail = result[end_pos:] mid = result[start_pos + 1 : end_pos] result = head + eval(f'_skip_ignored_args("{name}", {mid})') + tail return result, mod_cache def noop_handler(node, inputs, outputs): """Generate modified outputs if the inputs and the outputs are the same""" with no_catch(): is_list = False is_tuple = False modified = False if isinstance(outputs, list): is_list = True elif isinstance(outputs, tuple): is_tuple = True elif isinstance(outputs, torch.Tensor): if id(outputs) in current_graph().tensor_pre_node_dict: return outputs.view(outputs.shape) if is_list or is_tuple: is_tracked = [ isinstance(t, torch.Tensor) and id(t) in current_graph().tensor_pre_node_dict for t in outputs ] modified = any(is_tracked) if modified: new_outputs = [t.view(t.shape) if d else t for t, d in zip(outputs, is_tracked)] if is_tuple: return tuple(new_outputs) else: return new_outputs else: return None def add_input_node(node: TraceNode, output_tensors): """Adds an input node to the current computation graph""" assert node is not None if not isinstance(output_tensors, (list, tuple)): output_tensors = [output_tensors] node.next_tensors.extend(output_tensors) for t in 
output_tensors: current_graph().tensor_pre_node_dict[id(t)] = node.unique_name current_graph().input_nodes.append(node) current_graph().nodes_map[node.unique_name] = node def add_constant_node(node: TraceNode, output_tensor): """Adds a constant node to the current computation graph""" assert node is not None actual_tensor = output_tensor if isinstance(output_tensor, torch.nn.Parameter): actual_tensor = output_tensor.data current_graph().tensor_parameter_dict[id(output_tensor)] = weakref.ref(actual_tensor) node.next_tensors = [actual_tensor] current_graph().tensor_pre_node_dict[id(output_tensor)] = node.unique_name current_graph().constant_nodes.append(node) current_graph().nodes_map[node.unique_name] = node def add_output_node(node: TraceNode, input_tensors): """Adds an output node to the current computation graph""" assert node is not None need_idx = True if not isinstance(input_tensors, (list, tuple)): input_tensors = [input_tensors] need_idx = False node.prev_tensors.extend(input_tensors) node.rev_index = need_idx for i, t in enumerate(input_tensors): node.prev_nodes.append(current_graph().nodes_map[current_graph().tensor_pre_node_dict[id(t)]]) if id(t) in current_graph().tensor_pre_index_dict: node.prev_indices.append(current_graph().tensor_pre_index_dict[id(t)]) else: node.prev_indices.append(None) current_graph().output_nodes.append(node) current_graph().nodes_map[node.unique_name] = node def add_forward_node(node: TraceNode, input_tensors, output_tensors): """Adds a forward node to the current computation graph""" assert node is not None if not isinstance(input_tensors, (list, tuple)): input_tensors = [input_tensors] need_idx = True if not isinstance(output_tensors, (list, tuple)): output_tensors = [output_tensors] need_idx = False flatten_inputs = [] for t in input_tensors: if isinstance(t, (list, tuple)): for rt in t: flatten_inputs.append(rt) else: flatten_inputs.append(t) node.prev_tensors.extend(flatten_inputs) node.next_tensors.extend(output_tensors) 
for i, t in enumerate(flatten_inputs): assert type(t) in ( torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter, torch.nn.quantized.FloatFunctional, ) or isinstance(t, torch.nn.Module), ( f'Input #{i} of {node.unique_name}({node.type()}) should be one of the following type ' ' [torch.dtype, torch.device, torch.Size, torch.Tensor,' f' torch.nn.Parameter,torch.nn.quantized.FloatFunctional, torch.nn.Module], but got {type(t)}' ) constant_handler(t, node.unique_name, node.full_name()) pre_node_name = current_graph().tensor_pre_node_dict[id(t)] node.prev_nodes.append(current_graph().nodes_map[pre_node_name]) if id(t) in current_graph().tensor_pre_index_dict: pre_node_index = current_graph().tensor_pre_index_dict[id(t)] log.debug(f'propagate pre_index tensor {pre_node_name} {pre_node_index}') node.prev_indices.append(pre_node_index) else: node.prev_indices.append(None) if isinstance(t, torch.nn.Parameter): if id(t) in current_graph().tensor_parameter_dict: node.prev_tensors[i] = current_graph().tensor_parameter_dict[id(t)]() else: node.prev_tensors[i] = node.prev_tensors[i].data current_graph().tensor_parameter_dict[id(t)] = weakref.ref(node.prev_tensors[i]) for i, t in enumerate(output_tensors): if isinstance(t, (list, tuple)): for j, rt in enumerate(t): assert type(rt) in (torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter), ( f'Output [{i}][{j}] of {node.unique_name}({node.type()}) should be one of the following type ' f' [torch.dtype, torch.device, torch.Size, torch.Tensor], but got {type(rt)}' ) current_graph().tensor_pre_node_dict[id(rt)] = node.unique_name if need_idx: log.debug(f'set pre_index tensor {i}, {j}') current_graph().tensor_pre_index_dict[id(rt)] = [i, j] else: assert type(t) in (torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter), ( f'Output #{i} of {node.unique_name}({node.type()}) should be one of the following type ' f' [torch.dtype, torch.device, torch.Size, torch.Tensor], but got 
{type(t)}' ) current_graph().tensor_pre_node_dict[id(t)] = node.unique_name if need_idx: log.debug(f'set pre_index tensor {i}') current_graph().tensor_pre_index_dict[id(t)] = i if isinstance(t, torch.nn.Parameter): if id(t) in current_graph().tensor_parameter_dict: node.next_tensors[i] = current_graph().tensor_parameter_dict[id(t)]() else: node.next_tensors[i] = node.next_tensors[i].data current_graph().tensor_parameter_dict[id(t)] = weakref.ref(node.next_tensors[i]) current_graph().forward_nodes.append(node) current_graph().nodes_map[node.unique_name] = node @contextlib.contextmanager def hook_modules(module): """Temporarily adds the hooks to a `nn.Module` for tracing""" hooks = [] def register_submodule_tracer(module): def _submodule_pre_tracer(module, input): log.debug(f'pre tracer in _submodule_pre_tracer in {type(module).__name__}') if lock(): skip_modules.add(weakref.ref(module)) lock(True) def _submodule_tracer(module, inputs, outputs): m_ref = weakref.ref(module) if m_ref in skip_modules: skip_modules.remove(m_ref) return None log.debug(f'tracer in _submodule_tracer in {type(module).__name__}') node = TraceNode(module) modified_outputs = noop_handler(node, inputs, outputs) if modified_outputs is None: add_forward_node(node, inputs, outputs) else: add_forward_node(node, inputs, modified_outputs) lock(False) return modified_outputs module_unique_name = current_graph().module_unique_name_dict[id(module)] if module_unique_name in current_graph().traced_modules: log.debug(f"module {module_unique_name} is traced") return None related = False if id(module) in module_constructor_traced: if ( id(module) in module_constructor_lines and module_constructor_weakrefs.get(id(module), type(None))() is not None ): related = True else: if type(module) in overridable_modules: related = True else: for m in overridable_modules: if isinstance(module, m): related = True break if related: hooks.append(module.register_forward_pre_hook(_submodule_pre_tracer)) 
hooks.append(module.register_forward_hook(_submodule_tracer)) current_graph().related_modules.append(module_unique_name) current_graph().traced_modules.append(module_unique_name) return None def _model_pre_tracer(module, inputs): log.debug('pre tracer in _model_pre_tracer') for i in inputs: node = TraceNode(TraceFunction("input")) add_input_node(node, i) def _model_tracer(module, inputs, outputs): log.debug('tracer in _model_tracer') if type(outputs) is torch.Tensor: node = TraceNode(TraceFunction("output")) add_output_node(node, outputs) elif isinstance(outputs, (list, tuple)): for i in outputs: if type(i) is torch.Tensor or ( isinstance(i, (list, tuple)) and all((type(x) is torch.Tensor for x in i)) ): node = TraceNode(TraceFunction("output")) add_output_node(node, i) else: log.warning( "Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or" " tuple" ) elif isinstance(outputs, dict): for k, v in outputs.items(): if type(v) is torch.Tensor or ( isinstance(v, (list, tuple)) and all((type(x) is torch.Tensor for x in v)) ): node = TraceNode(TraceFunction("output")) add_output_node(node, v) else: log.warning( "Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or" " tuple" ) else: log.warning(f'Output type is not supported: {type(outputs).__name__}, try to extract tensors from it') for k in outputs.__dir__(): v = getattr(outputs, k) if type(v) is torch.Tensor or (type(v) in (list, tuple) and all((type(x) is torch.Tensor for x in v))): node = TraceNode(TraceFunction("output")) add_output_node(node, v) log.debug('trace: apply register_submodule_tracer') module.apply(register_submodule_tracer) log.debug('trace: add hooks') hooks.append(module.register_forward_pre_hook(_model_pre_tracer)) hooks.append(module.register_forward_hook(_model_tracer)) yield module for hook in hooks: hook.remove() @contextlib.contextmanager def tracer_context(): """Basic context manager for tracing""" yield True 
lock(False) module_constructor_traced.clear() module_constructor_lines.clear() module_constructor_weakrefs.clear() original_values_for_tracked_objects.clear() skip_modules.clear() @contextlib.contextmanager def model_constructor_tracer(): """Basic context manager for capturing constructors for `nn.Module`""" with patch_helper(wrap_funcs=False, wrap_creation_funcs=False, wrap_tracking_modules=False): yield True @contextlib.contextmanager def ignore_mod_param_update_warning(): if not mod_param_update_warning_ignore(): mod_param_update_warning_ignore(True) yield True mod_param_update_warning_ignore(False) else: yield False @contextlib.contextmanager def model_tracer(): """Simple context manager for tracing. Also captures module constructors""" with tracer_context(): with model_constructor_tracer(): yield True @contextlib.contextmanager def construct_trace_graph( module, dummy_input: torch.Tensor, eliminate_dead_graph: bool, patch_torch_size: bool ) -> 'TraceGraph': """Simple context manager for creating a new TraceGraph""" current_graph(TraceGraph(module, dummy_input, eliminate_dead_graph, patch_torch_size)) yield current_graph.get_value() current_graph(None) @contextlib.contextmanager def override_current_trace_graph(new_graph: 'TraceGraph') -> 'TraceGraph': """Simple context manager for creating a new TraceGraph""" old_graph = current_graph.get_value() current_graph(new_graph) yield current_graph.get_value() current_graph(old_graph) @contextlib.contextmanager def import_patcher(): with tracer_context(): with patch_helper(wrap_creation_funcs=False, wrap_funcs=True, wrap_modules=False, wrap_tracking_modules=False): with no_catch(): yield True class TraceGraph(object): """A data structure for storing a computation graph""" global_functions: typing.Dict[str, int] global_nodes: typing.Dict[str, int] module_unique_name_dict: typing.Dict[int, torch.nn.Module] module_original_name_dict: typing.Dict[int, str] traced_modules: typing.List[str] input_nodes: 
typing.List[TraceNode] forward_nodes: typing.List[TraceNode] output_nodes: typing.List[TraceNode] constant_nodes: typing.List[TraceNode] other_init_nodes: typing.List[TraceNode] nodes_map: typing.Dict[str, TraceNode] tensor_pre_node_dict: typing.Dict[int, str] tensor_pre_index_dict: typing.Dict[int, int] module: torch.nn.Module dummy_input: torch.Tensor eliminate_dead_graph: bool inited: bool quantized: bool code: str def __init__( self, module: torch.nn.Module, dummy_input: torch.Tensor, eliminate_dead_graph: bool = False, patch_torch_size: bool = False, ): # Used for function / node numbering self.global_functions = {} self.global_nodes = {} # Unique name for modules and submodules self.module_unique_name_dict = {} self.module_original_name_dict = {} # Unique name for buffers and parameters self.parameter_original_name_dict = {} self.parameter_unique_name_dict = {} # Mapping between parameters and their parent modules self.parameter_module_dict = {} # Recording traced modules self.traced_modules = [] self.related_modules = [] # Recording nodes self.input_nodes = [] self.forward_nodes = [] self.output_nodes = [] self.constant_nodes = [] self.other_init_nodes = [] # Node <-> name mapping self.nodes_map = {} # Recording the previous node of the tensors self.tensor_pre_node_dict = {} # Recording the previous index of the tensors self.tensor_pre_index_dict = {} # Recording the tensor object of the parameters self.tensor_parameter_dict = {} # Input module if isinstance(module, DataParallel) or isinstance(module, DistributedDataParallel): log.error( 'You are tracing a parallel module, which is unsupported. Please pass in a raw model using `.module`.' ) assert False else: self.module = module # Input data self.dummy_input = dummy_input # Whether to keep inactive nodes after tracing self.eliminate_dead_graph = eliminate_dead_graph # Let's give the module and its children a name. 
self.__tag_nodes() # Whether the tracing is completed or not self.inited = False # Whether the graph is rewrited to be a quantized one self.quantized = False # Generated code self.code = "None" # Used namespaces self.used_namespaces = set() # Tracer options self.patch_torch_size = patch_torch_size def all_nodes(self) -> typing.List[TraceNode]: """Returns all the nodes in a computation graph during forward process""" return self.input_nodes + self.forward_nodes + self.output_nodes + self.constant_nodes def all_tensors(self) -> typing.List[torch.Tensor]: tensors = dict() for n in self.all_nodes(): for t in n.prev_tensors + n.next_tensors: if isinstance(t, torch.Tensor): tensors[id(t)] = t return list(tensors.values()) def __tag_nodes(self) -> None: """Gives the modules and the submodules a unique name""" # Tag submodules for n, m in self.module.named_modules(): n = n.replace(".", "_") n = n.replace("-", "_") n = n.replace("/", "_") n = n.replace(";", "_") n = 'module_' + n if n.isnumeric() else n self.module_unique_name_dict[id(m)] = n # Tag the module itself self.module_unique_name_dict[id(self.module)] = type(self.module).__name__ for n, p in self.module.named_parameters(): n = n.replace(".", "_") n = n.replace("-", "_") n = n.replace("/", "_") n = n.replace(";", "_") n = 'param_' + n if n.isnumeric() else n self.parameter_unique_name_dict[id(p)] = n for n, b in self.module.named_buffers(): n = n.replace(".", "_") n = n.replace("-", "_") n = n.replace("/", "_") n = n.replace(";", "_") n = 'buffer_' + n if n.isnumeric() else n self.parameter_unique_name_dict[id(b)] = n q = queue.Queue() q.put((self.module, '')) while not q.empty(): m, p = q.get() if isinstance(m, nn.Module): self.module_original_name_dict[id(m)] = p for n, c in m.named_children(): if isinstance(m, (nn.Sequential, nn.ModuleList)) and n.isnumeric(): c_p = f'{p}[{n}]' elif isinstance(m, nn.ModuleDict): c_p = f'{p}["{n}"]' else: if len(p) > 0: if '.' 
in n or '-' in n: c_p = f'{p}.get_submodule("{n}")' else: c_p = f'{p}.{n}' else: c_p = n q.put((c, c_p)) for n, c in m.named_parameters(recurse=False): self.parameter_module_dict[id(c)] = id(m) if len(p) > 0: if '.' in n or '-' in n: c_p = f'{p}.get_parameter("{n}")' else: c_p = f'{p}.{n}' else: c_p = n q.put((c, c_p)) for n, c in m.named_buffers(recurse=False): self.parameter_module_dict[id(c)] = id(m) if len(p) > 0: if '.' in n or '-' in n: c_p = f'{p}.get_buffer("{n}")' else: c_p = f'{p}.{n}' else: c_p = n q.put((c, c_p)) else: self.parameter_original_name_dict[id(m)] = p def __active_detection(self, node: TraceNode): """Detects whether the node is active or not""" q = queue.Queue() q.put(node) while not q.empty(): node = q.get() if not node.active: node.active = True for i in node.prev_nodes: q.put(i) def init(self) -> None: """Builds a computation graph""" if self.inited: return with no_catch(): device = get_module_device(self.module) with self.__numbering_context(): if type(self.dummy_input) is torch.Tensor: actual_input = [self.dummy_input] elif isinstance(self.dummy_input, (tuple, list)): actual_input = list(self.dummy_input) else: log.error(f'Unsupported type {type(self.dummy_input)} for dummy input') assert False for i in range(len(actual_input)): dummy_input = actual_input[i] if type(dummy_input) is torch.Tensor: new_input = dummy_input.detach().clone() if new_input.is_floating_point(): new_input.requires_grad = True if new_input.device != device: new_input = new_input.to(device=device) actual_input[i] = new_input original_state_dict = copy.deepcopy(self.module.state_dict()) with patch_helper(wrap_tracking_modules=self.patch_torch_size): with hook_modules(self.module): self.module(*actual_input) self.module.load_state_dict(original_state_dict) if self.eliminate_dead_graph: self.eliminate_dead_graph_pass() self.recompute_forward_order() self.inited = True def eliminate_dead_graph_pass(self): for n in self.input_nodes + self.forward_nodes + 
self.output_nodes: n.active = False for i in self.output_nodes: self.__active_detection(i) active_input_nodes = [i for i in self.input_nodes if i.active] active_forward_nodes = [i for i in self.forward_nodes if i.active] active_constant_nodes = dict() for n in self.forward_nodes + self.output_nodes: if n.active: for pn in n.prev_nodes: if isinstance(pn.module, ConstantNode): active_constant_nodes[pn.unique_name] = pn self.input_nodes = active_input_nodes self.forward_nodes = active_forward_nodes self.constant_nodes = list(active_constant_nodes.values()) def add_inputs_for_tensors(self, input_tensors): """Add input nodes for specific tensors""" with self.__numbering_context(): self.global_functions['input'] = len(self.input_nodes) - 1 for i, t in enumerate(input_tensors): node = self.nodes_map[self.tensor_pre_node_dict[id(t)]] with override_current_trace_graph(self): new_node = TraceNode(TraceFunction("input")) add_input_node(new_node, t) node.prev_nodes.append(new_node) self.recompute_forward_order() def add_inputs_for_tensors_in_node(self, input_node, input_tensors): """Add input nodes for specific tensors with a node""" with self.__numbering_context(): self.global_functions['input'] = len(self.input_nodes) - 1 for i, t in enumerate(input_tensors): input_node.prev_tensors.append(t) input_node.prev_indices.append(None) with override_current_trace_graph(self): new_node = TraceNode(TraceFunction("input")) add_input_node(new_node, t) input_node.prev_nodes.append(new_node) self.recompute_forward_order() def add_outputs_for_tensors(self, output_tensors): """Add output nodes for specific tensors""" with self.__numbering_context(): self.global_functions['output'] = len(self.output_nodes) - 1 for i, t in enumerate(output_tensors): node = self.nodes_map[self.tensor_pre_node_dict[id(t)]] with override_current_trace_graph(self): new_node = TraceNode(TraceFunction("output")) add_output_node(new_node, t) node.next_nodes.append(new_node) self.recompute_forward_order() def 
add_state_input_outputs(self): for node in self.forward_nodes: if not isinstance(node.module, (nn.LSTM, nn.GRU, nn.RNN)): continue if len(node.prev_tensors) == 0 and len(node.next_tensors) == 0: continue input_tensor = node.prev_tensors[0] max_batch_size = input_tensor.size(0) if node.module.batch_first else input_tensor.size(1) num_directions = 2 if node.module.bidirectional else 1 real_hidden_size = ( node.module.proj_size if getattr(node.module, 'proj_size', 0) > 0 else node.module.hidden_size ) hidden_shape = (node.module.num_layers * num_directions, max_batch_size, real_hidden_size) if len(node.prev_nodes) == 1: hx = torch.zeros( hidden_shape, dtype=input_tensor.dtype, device=input_tensor.device, ) if isinstance(node.module, nn.LSTM): cx = torch.zeros( hidden_shape, dtype=input_tensor.dtype, device=input_tensor.device, ) self.add_inputs_for_tensors_in_node(node, [hx, cx]) indices = (-2, -1) else: self.add_inputs_for_tensors_in_node(node, [hx]) indices = (-1,) for i in indices: dtype_str = str(self.input_nodes[i].next_tensors[0].dtype).replace('torch.', '') shape_str = str([x for x in self.input_nodes[i].next_tensors[0].shape]) log.warning(f'Input added: {self.input_nodes[i].unique_name}: {dtype_str}{shape_str}') if len(node.next_nodes) == 1: if isinstance(node.module, nn.LSTM): self.add_outputs_for_tensors(node.next_tensors[1]) indices = (-2, -1) else: self.add_outputs_for_tensors([node.next_tensors[1]]) indices = (-1,) for i in indices: dtype_str = str(self.output_nodes[i].prev_tensors[0].dtype).replace('torch.', '') shape_str = str([x for x in self.output_nodes[i].prev_tensors[0].shape]) log.warning(f'Output added: {self.output_nodes[i].unique_name}: {dtype_str}{shape_str}') def reset_input_output_for_graph(self, input_names: typing.List[str], output_names: typing.List[str]): """Extract a subgraph from the original computation graph""" assert len(set(input_names) & set(output_names)) == 0, "A node cannot be both input and output" input_nodes = [] 
output_nodes = [] for name in input_names + output_names: node = self.nodes_map.get(name, None) assert node is not None, f"{name} is not a node in TraceGraph" assert node in self.forward_nodes, f"{name} is not a forward node in TraceGraph" if name in input_names: input_nodes.append(node) else: output_nodes.append(node) self.assert_is_subgraph(input_nodes, output_nodes) for node in self.input_nodes + self.output_nodes: del self.nodes_map[node.unique_name] self.input_nodes.clear() self.output_nodes.clear() with self.__numbering_context(): for node in input_nodes: node.prev_nodes.clear() for i, t in enumerate(node.prev_tensors): with override_current_trace_graph(self): new_node = TraceNode(TraceFunction("input")) add_input_node(new_node, t) node.prev_nodes.append(new_node) for node in output_nodes: node.next_nodes.clear() for i, t in enumerate(node.next_tensors): if type(t) is torch.Tensor or ( isinstance(t, (list, tuple)) and all((type(x) is torch.Tensor for x in t)) ): with override_current_trace_graph(self): new_node = TraceNode(TraceFunction("output")) add_output_node(new_node, t) node.next_nodes.append(new_node) else: log.warning( "Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or" " tuple" ) self.eliminate_dead_graph_pass() self.recompute_forward_order() def assert_is_subgraph(self, input_nodes, output_nodes): q = queue.Queue() for node in output_nodes: q.put(node) start_nodes = {node.unique_name: node for node in input_nodes} visited = set() actual_starts = dict() while not q.empty(): node = q.get() if node.unique_name in visited: continue visited.add(node.unique_name) if node.unique_name in start_nodes: continue if len(node.prev_nodes) == 0: if not len(node.next_nodes) == 0: actual_starts[node.unique_name] = node continue for prev_node in node.prev_nodes: q.put(prev_node) missing_inputs = set(actual_starts) - set(start_nodes) assert len(missing_inputs) == 0, f"Not a subgraph, missing inputs: {missing_inputs}" 
@contextlib.contextmanager def __numbering_context(self): """A simple context manager for numbering nodes""" yield True self.global_functions.clear() self.global_nodes.clear() def __gen_init_code(self) -> str: """Generates the code for the init function for a `nn.Module`""" generated_node = [] lines = [] mod_ids = [] mod_cache_dict = dict() for node in self.constant_nodes + self.forward_nodes + self.other_init_nodes: if node.unique_name in generated_node or id(node.module) in mod_ids: log.info(f"skip dumplicate node code gen {node.unique_name}") continue generated_node.append(node.unique_name) mod_ids.append(id(node.module)) if id(node.module) in module_constructor_lines: root_ns = qualified_name(node.type()).split('.')[0] if isinstance(node.module, TraceFunction) and ( node.module.full_name.startswith('torch.') or node.module.full_name.startswith('self.') or '.' not in node.module.full_name ): continue self.used_namespaces.add(root_ns) orig_constructor_line = module_constructor_lines[id(node.module)] line = f' self.{node.unique_name} = {orig_constructor_line}' lines.append(line) elif type(node.module) is ConstantNode: # Parameter generation self.used_namespaces.add('torch') requires_grad_prop = '' if node.module.is_parameter != node.module.requires_grad: requires_grad_prop = f', requires_grad={node.module.requires_grad}' if node.module.is_parameter: line = ( f' self.register_parameter("{node.unique_name}",' f' torch.nn.Parameter(torch.empty({node.module.shape}, dtype={node.module.dtype})' f'{requires_grad_prop}))' ) elif node.module.is_persistent: line = ( f' self.register_buffer("{node.unique_name}", torch.empty({node.module.shape},' f' dtype={node.module.dtype}{requires_grad_prop}))' ) else: line = ( f' self.register_buffer("{node.unique_name}", torch.tensor({node.module.data_str},' f' dtype={node.module.dtype}{requires_grad_prop}), persistent=False)' ) lines.append(line) elif type(node.module) is not TraceFunction: # Generate the module even if the constructor 
is not caught log.info( f'the constructor of the module {node.unique_name} of type {type(node.module).__name__} is not' ' traced, trying the experimental way' ) root_ns = qualified_name(node.type()).split('.')[0] self.used_namespaces.add(root_ns) orig_constructor_line, mod_cache = gen_module_constrctor_line(node.module, mod_cache_dict) line = f' self.{node.unique_name} = {orig_constructor_line}' lines.append(line) mod_cache_dict.update(mod_cache) block = "\n".join(lines) return block def __gen_forward_code(self, inplace=False) -> str: """Generates the code for the forward function for a `nn.Module`""" lines = [f" def forward(self, {','.join([i.unique_name for i in self.input_nodes])}):"] mod_name_dict = {} for node in self.forward_nodes: output = ", ".join([node.unique_name]) param = ", ".join([node.prev_node_unique_name(i, inplace) for i in range(len(node.prev_nodes))]) if type(node.module) is TraceFunction: full_name = node.full_name() if not full_name.startswith('torch.') and not full_name.startswith('self.') and '.' 
in full_name: ns = '.'.join(full_name.split('.')[:-1]) self.used_namespaces.add(ns) first_arg = None if node.is_class(): first_arg = node.prev_node_unique_name(0, inplace) if node.type().startswith('__i') and node.type().endswith('__'): inner_op = node.module.func_type[3:-2] if inner_op in SPECIAL_OPERATORS: node.module.func_type = f'__{inner_op}__' parts = node.module.full_name.split('.')[:-1] + [node.module.func_type] node.module.full_name = '.'.join(parts) if first_arg is not None: alias = first_arg else: alias = node.module.get_tensor_name(0, inplace) node.module.add_alias(alias) aliases = node.module.get_aliases() prefix = '' if aliases is not None: prefix = ''.join([f'{x} = ' for x in aliases]) line = f" {prefix}{output} = {node.module.extra_expr(first=first_arg, original=inplace)}" else: if inplace: mod_name = node.original_name else: mod_name_dict.setdefault(node.module, node.unique_name) mod_name = mod_name_dict[node.module] if len(node.prev_tensors) == 0 and len(node.next_tensors) == 0: continue if node.type() is nn.LSTM and len(node.prev_nodes) == 3 and len(node.prev_tensors) == 3: first_arg = node.prev_node_unique_name(0) param = ", ".join([node.prev_node_unique_name(i) for i in range(1, len(node.prev_nodes))]) line = f" {output} = self.{mod_name}({first_arg}, ({param}))" else: line = f" {output} = self.{mod_name}({param})" lines.append(line) for pn in {pn.unique_name: pn for pn in node.prev_nodes}.values(): if node.forward_order == max([n.forward_order for n in pn.next_nodes]): if pn.type() not in (ConstantNode, torch.nn.quantized.FloatFunctional): lines.append(f" {pn.unique_name} = None") def _gen_output_node(node): if node.rev_index: return f'[{", ".join([node.prev_node_unique_name(i) for i in range(len(node.prev_nodes))])}]' else: return node.prev_node_unique_name(0) lines.append(f" return {', '.join([_gen_output_node(i) for i in self.output_nodes])}") block = "\n".join(lines) return block def __gen_import_code(self) -> str: """Generates the code 
for the import section for a `nn.Module`""" # TODO: Selective module importing import_block = """import torch\nimport torch.nn\nimport torch.functional\nimport torch.nn.functional""" if self.quantized is True: import_block += '\nimport torch.quantization\nimport torch.nn.quantized' additional_ns = sorted(self.used_namespaces - set(['torch'])) import_block += ''.join([f'\nimport {ns}' for ns in additional_ns]) return import_block def __gen_input_code(self) -> str: """Generates the code for the input section for the code to invoke `nn.Module`""" input_block = "" for i, node in enumerate(self.input_nodes): shape = ", ".join((str(i) for i in node.next_tensors[0].shape)) dtype = node.next_tensors[0].dtype if i != 0: input_block += "\n" input_block += f" dummy_input_{i} = torch.ones(({shape}), dtype={dtype})" return input_block def recompute_forward_order(self): forward_order = 0 for n in self.input_nodes + self.forward_nodes + self.output_nodes: n.forward_order = forward_order forward_order += 1 for prev in n.prev_nodes: if n not in prev.next_nodes: prev.next_nodes.append(n) def generate_code( self, output_script_path: typing.Optional[str], output_weight_path: typing.Optional[str], model_name: str = 'DefaultModel', check: bool = False, ) -> bool: """The main function for code generation""" output_paths = (output_script_path, output_weight_path) for output_path in output_paths: if not output_path: continue output_dir = os.path.dirname(output_path) if output_dir != '' and not os.path.exists(output_dir): os.makedirs(output_dir) DummyModel = type('DummyModel', (torch.nn.Module,), {}) dummy_model = DummyModel() mod_ids = set() for node in self.forward_nodes: if id(node.module) not in mod_ids: setattr(dummy_model, node.unique_name, node.module) mod_ids.add(id(node.module)) for node in self.constant_nodes: if node.module.is_parameter or node.module.is_persistent: dtype = getattr(torch, node.module.dtype.split('.')[-1]) new_tensor = torch.tensor(node.module.data, dtype=dtype) 
if node.module.is_parameter: weight = torch.nn.Parameter(new_tensor) dummy_model.register_parameter(node.unique_name, weight) else: dummy_model.register_buffer(node.unique_name, new_tensor) if output_weight_path: torch.save(dummy_model.state_dict(), output_weight_path) output_weight_path_str = output_weight_path.replace('\\', '\\\\') init_block = self.__gen_init_code() forward_block = self.__gen_forward_code() import_block = self.__gen_import_code() input_block = self.__gen_input_code() context = { "import_block": import_block, "init_block": init_block, "forward_block": forward_block, "name_block": model_name, "load_weight_block": "" if output_weight_path is None else f" model.load_state_dict(torch.load('{output_weight_path_str}'))", "input_block": input_block, "input_names": ", ".join([f"dummy_input_{i}" for i in range(len(self.input_nodes))]), } code = MODULE_TEMPLATE % context if output_script_path: if os.path.exists(output_script_path): os.remove(output_script_path) with io.open(output_script_path, 'w') as f: f.write(code) if check: training_state = self.module.training valid = True with torch.no_grad(): self.module.eval() original_outputs = tensors2ndarray(self.module(*self.dummy_input)) new_module = import_from_path(f'tracer_check.{model_name}', output_script_path, model_name)() new_module.eval() if output_weight_path: new_module.load_state_dict(torch.load(output_weight_path)) new_outputs = tensors2ndarray(new_module(*self.dummy_input)) for i in range(len(original_outputs)): output = original_outputs[i] new_output = new_outputs[i] if not np.allclose(output, new_output): log.warning(f"[WARNING] Output {i} is not equal.") valid = False if training_state: self.module.train() return valid else: return True def inplace_commit(self, show_code: bool = False): """Commit the changes to the TraceGraph and applies it to the model""" forward_block = self.__gen_forward_code(True) import_block = self.__gen_import_code() code = re.sub(r'^ ', '', forward_block, 
flags=re.MULTILINE) if show_code: log.warning(f'The new forward function for the model:\n{code}') tmp_ns = {} exec(import_block, tmp_ns) exec(code, tmp_ns) new_func = tmp_ns['forward'] setattr(self.module, 'forward', types.MethodType(new_func, self.module)) for node in self.forward_nodes + self.other_init_nodes: if isinstance(node.module, nn.Module) and node.unique_name not in self.related_modules: setattr(self.module, node.original_name, node.module) for node in self.constant_nodes: if not node.module.inplace: setattr(self.module, node.original_name, node.next_tensors[0]) return self.module def update_submodule_in_nodes_from_predicate( self, nodes: typing.List[TraceNode], module_gen_predicate: typing.Callable[[nn.Module], nn.Module], inplace: bool = False, ): """update a submodule from the nodes using the predicate given""" for node in nodes: module = node.module new_module = module_gen_predicate(module) self.update_submodule_in_node(node, new_module, inplace) def get_submodule_with_parent_from_name(self, module_name: str, inplace: bool = False): """Gets the submodule with its parent using the name given""" if inplace: module_name = re.sub('get_submodule\\("(.*?)"\\)', '\\1', module_name) module_name = re.sub('\\[("|)(.*?)("|)\\]', '.\\2', module_name) module_name_parts = module_name.split('.') cur_obj = self.module last_obj = None for ns in module_name_parts: last_obj = cur_obj if type(cur_obj) is nn.ModuleList: cur_obj = cur_obj[int(ns)] elif type(cur_obj) is nn.ModuleDict: cur_obj = cur_obj[ns] else: cur_obj = getattr(cur_obj, ns) return cur_obj, last_obj def update_submodule_in_node(self, node: TraceNode, module: nn.Module, inplace: bool = False): """update a submodule from the nodes using the module given""" module_name = self.module_original_name_dict[id(node.module)] if inplace: module_name = re.sub('get_submodule\\("(.*?)"\\)', '\\1', module_name) module_name = re.sub('\\[("|)(.*?)("|)\\]', '.\\2', module_name) module_name_parts = module_name.split('.') 
cur_obj = self.module for ns in module_name_parts[:-1]: if type(cur_obj) is nn.ModuleList: cur_obj = cur_obj[int(ns)] elif type(cur_obj) is nn.ModuleDict: cur_obj = cur_obj[ns] else: cur_obj = getattr(cur_obj, ns) ns = module_name_parts[-1] new_obj = module if type(cur_obj) is nn.ModuleList: cur_obj[int(ns)] = new_obj elif type(cur_obj) is nn.ModuleDict: cur_obj[ns] = new_obj else: setattr(cur_obj, ns, new_obj) def filter_forward_nodes(self, predicate, custom_data=None, reverse=False) -> typing.List[TraceNode]: """A utility function for filtering forward nodes""" nodes = [] iter_nodes = self.forward_nodes if reverse: iter_nodes = reversed(iter_nodes) for node in iter_nodes: if predicate(node, custom_data): nodes.append(node) return nodes def insert_after(self, node: TraceNode, module, next_tensors: typing.Optional[typing.List[torch.Tensor]] = None): """Insert a module or an existing node after a node in the computation graph""" # Create a new node and connects it to the next node/tensors if type(module) is not TraceNode: new_node = TraceNode(module, cur_graph=self) if node in self.input_nodes or node in self.constant_nodes: self.forward_nodes.insert(0, new_node) elif node in self.output_nodes: log.error('You cannot insert a node after output nodes') assert False else: idx = self.forward_nodes.index(node) self.forward_nodes.insert(idx + 1, new_node) self.nodes_map[new_node.unique_name] = new_node else: new_node = module is_constant_node = type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) new_node.prev_nodes.append(node) new_node.next_nodes.extend(node.next_nodes) if next_tensors is None: next_tensors = [None] * len(node.next_tensors) for t, new_t in zip(node.next_tensors, next_tensors): if new_t is None: new_t = t.clone() self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name new_node.prev_tensors.append(t) new_node.next_tensors.append(new_t) new_node.prev_indices.append(None) # Make input/constant nodes connects to the new node 
node.next_nodes.clear() node.next_nodes.append(new_node) # Connect the next nodes to the new node tensor_replace_dict = dict(zip(new_node.prev_tensors, new_node.next_tensors)) for next_node in new_node.next_nodes: is_next_constant_node = type(next_node) in (ConstantNode, torch.nn.quantized.FloatFunctional) for i, n in enumerate(next_node.prev_nodes): if n == node: next_node.prev_nodes[i] = new_node # Make sure the data is writable if isinstance(next_node.prev_tensors, tuple): next_node.prev_tensors = list(next_node.prev_tensors) updated_indices = [] for i, t in enumerate(next_node.prev_tensors): if t in tensor_replace_dict: next_node.prev_tensors[i] = tensor_replace_dict[t] updated_indices.append(next_node.prev_indices[i]) # Since the function calls are rendered beforehand, # we need to change them as well. if type(next_node.module) is TraceFunction: if next_node.module.args_string is not None: for idx in updated_indices: old_unique_name = tensor_name_from_parts(node.unique_name, idx, is_constant_node) new_unique_name = tensor_name_from_parts(new_node.unique_name, idx, is_next_constant_node) next_node.module.replace_tensor_name(old_unique_name, new_unique_name) next_node.module.update_args_string() def insert_new_after( self, node, module_or_func, prev_tensors: typing.List[torch.Tensor], prev_indices: typing.List[torch.Tensor], next_tensors: typing.Optional[typing.List[torch.Tensor]] = None, before_node: typing.Optional[TraceNode] = None, ): assert type(module_or_func) is not TraceNode new_node = TraceNode(module_or_func, cur_graph=self) if next_tensors is None: next_tensors = [t.clone() for t in prev_tensors] for new_t, new_i in zip(next_tensors, prev_indices): self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name if new_i is not None: self.tensor_pre_index_dict[id(new_t)] = new_i new_node.prev_tensors.extend(prev_tensors) new_node.next_tensors.extend(next_tensors) new_node.prev_indices.extend(prev_indices) new_node.prev_nodes.append(node) 
node.next_nodes.append(new_node) if before_node is not None: idx = self.forward_nodes.index(before_node) self.forward_nodes.insert(idx, new_node) else: self.forward_nodes.append(new_node) self.nodes_map[new_node.unique_name] = new_node return new_node def insert_between( self, prev_node: TraceNode, next_node: TraceNode, module, next_tensors: typing.Optional[typing.List[torch.Tensor]] = None, move_idx: bool = False, tensor_ptrs: typing.Optional[typing.Set[int]] = None, ): """Insert a module or an existing node between two nodes in the computation graph""" # Create a new node and connects it to the previous node/tensors old_unique_name = prev_node.unique_name is_constant_node = type(prev_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) if type(module) is not TraceNode: new_node = TraceNode(module, cur_graph=self) if prev_node not in next_node.prev_nodes or next_node not in prev_node.next_nodes: log.error('You cannot insert a node between two nodes that is not connected') assert False idx = self.forward_nodes.index(next_node) self.forward_nodes.insert(idx, new_node) self.nodes_map[new_node.unique_name] = new_node new_node.prev_nodes.append(prev_node) new_node.next_nodes.append(next_node) else: new_node = module if prev_node not in new_node.prev_nodes: new_node.prev_nodes.append(prev_node) if next_node not in new_node.next_nodes: new_node.next_nodes.append(next_node) if next_tensors is None: next_tensors = list(new_node.next_tensors) new_node.prev_tensors.clear() new_node.next_tensors.clear() new_node.prev_indices.clear() is_new_constant_node = type(new_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) # Gather tensors from previous nodes prev_tensors = [] prev_indices = [] remaining_count = 0 for pt, pidx in zip(next_node.prev_tensors, next_node.prev_indices): skip_node = tensor_ptrs is not None and id(pt) not in tensor_ptrs for nt in prev_node.next_tensors: if id(pt) == id(nt): if not skip_node: prev_tensors.append(pt) 
prev_indices.append(pidx) else: remaining_count += 1 break elif isinstance(nt, (list, tuple)): for i, ntt in enumerate(nt): if id(ntt) == id(pt): if not skip_node: prev_tensors.append(pt) prev_indices.append(pidx) else: remaining_count += 1 break if next_tensors is None: next_tensors = [None] * len(prev_tensors) assert len(next_tensors) > 0, f'{prev_node.unique_name} and {next_node.unique_name} has no common tensors' for idx, (t, new_t, pidx) in enumerate(zip(prev_tensors, next_tensors, prev_indices)): if new_t is None: new_t = t.clone() next_tensors[idx] = new_t self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name new_node.prev_tensors.append(t) new_node.next_tensors.append(new_t) new_node.prev_indices.append(pidx if move_idx else None) # Make output nodes connects to the new node for idx in range(len(next_node.prev_nodes) - remaining_count): if next_node.prev_nodes[idx] == prev_node: next_node.prev_nodes[idx] = new_node # Update tensors in output nodes index_mapping = [] for idx, t in enumerate(next_node.prev_tensors): for pt, nt in zip(prev_tensors, next_tensors): if id(t) == id(pt): next_node.prev_tensors[idx] = nt old_index = next_node.prev_indices[idx] if move_idx: next_node.prev_indices[idx] = None new_index = next_node.prev_indices[idx] index_mapping.append((old_index, new_index)) break # Connect the previous nodes to the new node for prev_node in new_node.prev_nodes: if remaining_count > 0: if new_node not in prev_node.next_nodes: prev_node.next_tensors.append(new_node) else: for i, n in enumerate(prev_node.next_nodes): if n == next_node: prev_node.next_nodes[i] = new_node break # Update previous node name for next nodes (TraceFunction) if type(next_node.module) is TraceFunction: for old_idx, new_idx in index_mapping: prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node) next_unique_name = tensor_name_from_parts(new_node.unique_name, new_idx, is_new_constant_node) log.debug(f'tensor rename: {prev_unique_name} -> 
{next_unique_name}') next_node.module.replace_tensor_name(prev_unique_name, next_unique_name) next_node.module.update_args_string() def insert_before( self, node: TraceNode, module, next_tensors: typing.Optional[typing.List[torch.Tensor]] = None, move_idx: bool = False, next_indices: typing.Optional[typing.List[int]] = None, ): """Insert a module or an existing node before a node in the computation graph""" # Create a new node and connects it to the previous node/tensors if type(module) is not TraceNode: if not isinstance(module, (tuple, list)): modules = [module] rev_mode = False else: if not node.rev_index: log.error('You can only insert nodes with a list modules when node.rev_index=True') assert False if len(module) != len(node.prev_nodes): log.error(f'The number of the modules provided is wrong, expected: {len(node.prev_nodes)}') assert False modules = module rev_mode = True new_nodes: typing.List[TraceNode] = [] for module in modules: new_node = TraceNode(module, cur_graph=self) new_nodes.append(new_node) else: new_nodes = [module] if node in self.input_nodes or node in self.constant_nodes: log.error('You cannot insert a node before input/constant nodes') assert False elif node in self.output_nodes: for new_node in new_nodes: self.forward_nodes.append(new_node) else: for new_node in new_nodes: idx = self.forward_nodes.index(node) self.forward_nodes.insert(idx, new_node) for idx, new_node in enumerate(new_nodes): self.nodes_map[new_node.unique_name] = new_node if not rev_mode: new_node.prev_nodes.extend(node.prev_nodes) else: new_node.prev_nodes.append(node.prev_nodes[idx]) new_node.next_nodes.append(node) index_mapping = [] for idx, new_node in enumerate(new_nodes): if not rev_mode: prev_tensors = node.prev_tensors prev_indices = node.prev_indices if move_idx: node.prev_indices = [None] * len(node.prev_indices) else: prev_tensors = node.prev_tensors[idx : idx + 1] prev_indices = node.prev_indices[idx : idx + 1] if move_idx: node.prev_indices[idx] = None if 
next_tensors is None: next_tensors = [None] * len(prev_tensors) if next_indices is None: next_indices = [None] * len(prev_tensors) for t, new_t, ind, n_ind in zip(prev_tensors, next_tensors, prev_indices, next_indices): if new_t is None: new_t = t.clone() self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name new_node.prev_tensors.append(t) new_node.next_tensors.append(new_t) new_node.prev_indices.append(ind if move_idx else n_ind) index_mapping.append((ind, new_node.prev_indices[-1])) # Make output nodes connects to the new node node.prev_nodes.clear() node.prev_nodes.extend(new_nodes) node.prev_tensors.clear() for new_node in new_nodes: node.prev_tensors.extend(new_node.next_tensors) # Connect the previous nodes to the new node if rev_mode: for prev_node in new_node.prev_nodes: idx = None for i, n in enumerate(prev_node.next_nodes): if n == node: idx = i break if idx is not None: for i, new_node in enumerate(new_nodes): prev_node.next_nodes.insert(idx + 1 + i, new_node) prev_node.next_nodes.pop(idx) else: for new_node in new_nodes: for prev_node in new_node.prev_nodes: for i, n in enumerate(prev_node.next_nodes): if n == node: prev_node.next_nodes[i] = new_node break # Update previous node name for next nodes (TraceFunction) if type(node.module) is TraceFunction and node not in self.output_nodes: new_node = new_nodes[0] for i in range(len(new_node.prev_nodes)): old_unique_name = new_node.prev_nodes[i].unique_name is_constant_node = type(new_node.prev_nodes[i].module) in ( ConstantNode, torch.nn.quantized.FloatFunctional, ) is_next_constant_node = type(new_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) old_idx, new_idx = index_mapping[i] prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node=is_constant_node) next_unique_name = tensor_name_from_parts( new_node.unique_name, new_idx, is_constant_node=is_next_constant_node ) log.debug(f'node rename: {prev_unique_name} -> {next_unique_name}') 
node.module.replace_tensor_name(prev_unique_name, next_unique_name) node.module.update_args_string() def __call__(self, *args, **kwargs): """Calls the function with a list of tensor inputs""" # Prepare inputs tensors_map = {} i = 0 for t in args: if isinstance(t, (tuple, list)): for rt in t: tensors_map[(f'input_{i}_f', None)] = rt i += 1 else: tensors_map[(f'input_{i}_f', None)] = t i += 1 # Sort nodes sorted_nodes = sorted(self.forward_nodes, key=lambda x: x.forward_order) for node in sorted_nodes: # Basic info op_kind = node.kind() op_kind = op_kind.__name__ if isinstance(op_kind, type) else op_kind log.debug(f'{node.unique_name}({op_kind}):') log.debug(' Inputs:') # TODO: Support the following case assert len(node.prev_nodes) == len(node.prev_tensors), "not supported" # Collect input tensors prev_tensors = [] for i, pn in enumerate(node.prev_nodes): pn_name = pn.unique_name pn_idx = node.prev_indices[i] log.debug(f' {pn_name} {pn_idx}') if isinstance(pn.module, ConstantNode): prev_tensors.append(node.prev_tensors[i]) elif (pn_name, pn_idx) in tensors_map: prev_tensors.append(tensors_map[(pn_name, pn_idx)]) else: assert False, f"({pn_name}, {pn_idx}) is not found in tensor dict" node.prev_tensors = prev_tensors # Calculate output if isinstance(node.module, nn.Module): output = node.module(*prev_tensors) else: output = node.module(prev_tensors) # Shape handling if node.type() in ('shape', 'size'): if len(prev_tensors) == 1: output = torch.tensor(output).unbind(0) else: output = torch.tensor(output) log.debug('') log.debug(' Outputs:') # Updating output tensors if isinstance(output, (list, tuple)): for j, t in enumerate(output): tensors_map[(node.unique_name, j)] = t node.next_tensors[j] = t else: tensors_map[(node.unique_name, None)] = output node.next_tensors[0] = output for i in range(len(node.next_tensors)): log.debug(f' {i}') log.debug('') def replace_node_module(self, node: TraceNode, module: torch.nn.Module) -> None: """Replaces a module in a node with 
another""" # Update unique name for node old_unique_name = node.unique_name is_constant_node = type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) is_new_constant_node = type(module) in (ConstantNode, torch.nn.quantized.FloatFunctional) node.unique_name = self.module_unique_name_dict[id(module)] # Drop original module constructor line if id(node.module) in module_constructor_lines: if id(node.module) in module_constructor_traced: module_constructor_traced.remove(id(node.module)) del module_constructor_lines[id(node.module)] # Update module for node node.module = module # Update node map del self.nodes_map[old_unique_name] self.nodes_map[node.unique_name] = node # Update previous node name for next tensors for t in node.next_tensors: self.tensor_pre_node_dict[id(t)] = self.tensor_pre_node_dict[id(t)].replace( old_unique_name, node.unique_name ) # Update previous node name for next nodes (TraceFunction) for n in node.next_nodes: for i, pt in enumerate(n.prev_tensors): for nt in node.next_tensors: if id(nt) == id(pt): idx = n.prev_indices[i] if type(n.module) is TraceFunction: prev_unique_name = tensor_name_from_parts(old_unique_name, idx, is_constant_node) next_unique_name = tensor_name_from_parts(node.unique_name, idx, is_new_constant_node) log.debug(f'node rename: {prev_unique_name} -> {next_unique_name}') n.module.replace_tensor_name(prev_unique_name, next_unique_name) n.module.update_args_string() break def fuse_nodes_to_func( self, nodes: typing.List[TraceNode], full_name: str, kind: str, func_type: str, is_class: bool ) -> None: """Fuses several nodes into one function""" if len(nodes) > 1: # Set the full name if the first node is already a TraceFunction # Otherwise, we need to construct one. 
next_nodes = [] next_tensors = [] if type(nodes[0].module) is TraceFunction: next_nodes.extend(nodes[-1].next_nodes) next_tensors.extend(nodes[-1].next_tensors) last_node_unique_name = nodes[-1].unique_name first_node_unique_name = nodes[0].unique_name for node in nodes[1:]: name = node.unique_name del self.nodes_map[name] self.forward_nodes.remove(node) if id(node.module) in module_constructor_lines: if id(node.module) in module_constructor_traced: module_constructor_traced.remove(id(node.module)) del module_constructor_lines[id(node.module)] node = nodes[0] node.next_nodes.clear() node.next_tensors.clear() node.next_nodes.extend(next_nodes) node.next_tensors.extend(next_tensors) node.module.func_type = func_type node.module.is_class = is_class node.module.kind = kind node.module.full_name = full_name # Update next tensors for t in node.next_tensors: self.tensor_pre_node_dict[id(t)] = self.tensor_pre_node_dict[id(t)].replace( last_node_unique_name, first_node_unique_name ) for n in node.next_nodes: # Update next nodes for i, pn in enumerate(n.prev_nodes): if pn.unique_name == last_node_unique_name: n.prev_nodes[i] = node for i, pt in enumerate(n.prev_tensors): for nt in node.next_tensors: if id(pt) == id(nt): idx = n.prev_indices[i] # Rewrite func calls in next nodes if type(n.module) is TraceFunction: old_unique_name = tensor_name_from_parts(last_node_unique_name, idx, False) new_unique_name = tensor_name_from_parts(first_node_unique_name, idx, False) n.module.replace_tensor_name(old_unique_name, new_unique_name) n.module.update_args_string() break else: # TODO: Implement this codepath log.error('Module fusion requires the first node to be a TraceFunction.') raise NotImplementedError else: log.warning('Calling fuse with less than 2 nodes is no-op.') def remove_node(self, node: TraceNode) -> None: """Remove a node from the computation graph""" if node not in self.forward_nodes: log.error('Only forward nodes can be removed') assert False if len(node.prev_nodes) != 
1: log.error('You cannot remove a node with multiple input nodes') assert False if len(node.prev_tensors) != len(node.next_tensors): log.error('You cannot remove a node in which the size of input tensors and the output tensors is different') assert False for idx, (prev_tensor, next_tensor) in enumerate(zip(node.prev_tensors, node.next_tensors)): if prev_tensor.shape != next_tensor.shape: log.error(f'The shape of the input/output at index {idx} mismatches') log.error(f'The shape of the input tensor is {prev_tensor.shape}') log.error(f'The shape of the output tensor is {next_tensor.shape}') assert False prev_node = node.prev_nodes[0] next_nodes = node.next_nodes tensor_dict = dict(zip(node.next_tensors, node.prev_tensors)) index_dict = dict(zip(node.next_tensors, node.prev_indices)) is_constant_node = type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) is_prev_constant_node = type(prev_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional) old_unique_name = node.unique_name # Deal with previous nodes if node in prev_node.next_nodes: prev_node.next_nodes.remove(node) else: log.error('Current node is not in the next nodes of the previous node') assert False prev_node.next_nodes.extend(next_nodes) # Deal with next nodes for n in next_nodes: # Handle previous nodes for i, pn in enumerate(n.prev_nodes): if pn == node: n.prev_nodes[i] = prev_node # Handle previous tensors for i, pt in enumerate(n.prev_tensors): if pt in tensor_dict: n.prev_tensors[i] = tensor_dict[pt] old_idx = n.prev_indices[i] n.prev_indices[i] = index_dict[pt] new_idx = n.prev_indices[i] # Rewrite func calls in next nodes if type(n.module) is TraceFunction: if n.module.args_string is not None: prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node) new_unique_name = tensor_name_from_parts( prev_node.unique_name, new_idx, is_prev_constant_node ) n.module.replace_tensor_name(prev_unique_name, new_unique_name) n.module.update_args_string() # 
Remove this node self.forward_nodes.remove(node) del self.nodes_map[node.unique_name] if id(node.module) in module_constructor_lines: if id(node.module) in module_constructor_traced: module_constructor_traced.remove(id(node.module)) del module_constructor_lines[id(node.module)] def load_overridable_modules(): if overridable_modules_loaded(): modules = overridable_modules else: modules = fetch_modules() overridable_modules.extend(modules) overridable_modules_loaded(True) return modules def load_overridable_funcs(): if overridable_funcs_loaded(): funcs = overridable_funcs else: funcs = fetch_funcs() config = os.path.join(current_dir, 'configs/torch_quant_stub_func.yml') funcs.update(fetch_funcs(config)) overridable_funcs.update(funcs) overridable_funcs_loaded(True) return funcs def load_torch_overrides_funcs(funcs): if not torch_overrides_funcs_loaded(): o_funcs, o_wrappers = prepare_torch_overrides_funcs(funcs) torch_overrides_funcs.extend(o_funcs) torch_overrides_wrappers.extend(o_wrappers) torch_overrides_funcs_loaded(True) else: o_funcs, o_wrappers = torch_overrides_funcs, torch_overrides_wrappers return o_funcs, o_wrappers def load_creation_funcs(): if overridable_creation_funcs_loaded(): creation_funcs = overridable_creation_funcs else: creation_funcs = fetch_funcs(os.path.join(current_dir, 'configs/torch_creation_funcs_override.yml')) overridable_creation_funcs.update(creation_funcs) overridable_creation_funcs_loaded(True) return creation_funcs def load_tracking_modules(): if tracking_modules_loaded(): tracking_modules = torch_tracking_modules else: tracking_modules = fetch_modules(os.path.join(current_dir, 'configs/torch_tracking_modules.yml')) torch_tracking_modules.extend(tracking_modules) tracking_modules_loaded(True) return tracking_modules @contextlib.contextmanager def patch_helper( wrap_modules: bool = True, wrap_funcs: bool = True, wrap_creation_funcs: bool = True, wrap_tracking_modules: bool = False, ): """Temporarily monkeypatches the functions and 
the modules in PyTorch.""" if wrap_modules: if not modules_overrided(): modules = load_overridable_modules() modules_overrided(True) else: wrap_modules = False if wrap_funcs: if not funcs_overrided(): funcs = load_overridable_funcs() o_funcs, o_wrappers = load_torch_overrides_funcs(funcs) funcs_overrided(True) else: wrap_funcs = False if wrap_creation_funcs: if not creation_funcs_overrided(): creation_funcs = load_creation_funcs() creation_funcs_overrided(True) else: wrap_creation_funcs = False if wrap_tracking_modules: if not tracking_modules_overrided(): tracking_modules = load_tracking_modules() tracking_modules_overrided(True) else: wrap_tracking_modules = False if wrap_modules: module_manager = patch_modules(modules, ('__init__', '__setattr__'), (new_init_gen, new_setattr_gen)) module_manager.__enter__() if wrap_funcs: tracked_funcs = [funcs, {torch.Tensor: ['__getattribute__']}] wrappers = [new_func_gen, new_getattr_gen] tracked_funcs.extend(o_funcs) wrappers.extend(o_wrappers) func_manager = patch_funcs(tracked_funcs, wrappers) func_manager.__enter__() if wrap_creation_funcs: creation_func_manager = patch_funcs(creation_funcs, new_creation_func_gen) creation_func_manager.__enter__() if wrap_tracking_modules: tracking_module_manager = patch_modules(tracking_modules, '__new__', new_init_tracking_gen) tracking_module_manager.__enter__() yield True if wrap_modules: module_manager.__exit__(None, None, None) modules_overrided(False) if wrap_funcs: func_manager.__exit__(None, None, None) funcs_overrided(False) if wrap_creation_funcs: creation_func_manager.__exit__(None, None, None) creation_funcs_overrided(False) if wrap_tracking_modules: tracking_module_manager.__exit__(None, None, None) tracking_modules_overrided(False) def check_types(values: typing.Iterable) -> bool: """Checks whether unsupported types are in the args.""" for value in values: if isinstance(value, (tuple, list)): res = check_types(value) if res is not None: return res elif not isinstance(value, 
(int, float, bool, str, type(None))): return type(value).__name__ return None def check_tensor_type(value) -> bool: """Check whether types are related to torch.Tensor.""" if isinstance(value, (tuple, list)): for item in value: res = check_tensor_type(item) if res: return res elif type(value) is torch.Tensor: return True return False def check_creation_args(args: typing.Iterable) -> typing.Tuple: """Cast arguments of type of Tensor to normal values""" new_args = [] for arg in args: if isinstance(arg, (tuple, list)): new_args.append(check_creation_args(arg)) elif type(arg) is torch.Tensor: if arg.dim() == 0: new_args.append(arg.item()) else: new_args.append(arg.tolist()) else: new_args.append(arg) return tuple(new_args) def tensor_name_from_parts(node_name, node_idx=None, is_constant_node=False): ns = 'self.' if is_constant_node else '' if node_idx is None: return f'{ns}{node_name}' else: if isinstance(node_idx, (list, tuple)): indices_str = ''.join([f'[{i}]' for i in node_idx]) return f'{ns}{node_name}{indices_str}' else: return f'{ns}{node_name}[{node_idx}]' def trace( module: torch.nn.Module, dummy_input: torch.Tensor, eliminate_dead_graph: bool = False, patch_torch_size: bool = False, ) -> TraceGraph: """main function for tracing""" try: with construct_trace_graph(module, dummy_input, eliminate_dead_graph, patch_torch_size) as new_graph: new_graph.init() return new_graph except Exception: traceback.print_exc() if current_graph() is not None: log.error(f'inputs: {[n.unique_name for n in current_graph().input_nodes]}') log.error(f'forwards: {[n.unique_name for n in current_graph().forward_nodes]}') log.error(f'outputs: {[n.unique_name for n in current_graph().output_nodes]}') log.error(f'constants: {[n.unique_name for n in current_graph().constant_nodes]}') quit()