# tinynn/graph/tracer.py
import contextlib
import copy
import ctypes
import importlib
import inspect
import io
import os
import queue
import re
import sys
import traceback
import typing
import weakref
import types
import torch
import torch.nn as nn
import yaml
import numpy as np
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from tinynn.util.train_util import get_module_device
from tinynn.util.util import get_logger, import_from_path, tensors2ndarray
from ._utils import patch_getitem, revert_getitem, patch_new, revert_new
from . import interop # noqa: F401
# Basic types
class GlobalData(object):
"""The data structure to store data that can be used in this script,
which is a wrapper of an object of a built-in type."""
def __init__(self, value):
super().__init__()
self.value = value
def get_value(self):
"""Returns the inner value of the wrapper"""
return self.value
def set_value(self, value):
"""Sets the inner value of the wrapper"""
self.value = value
def __str__(self):
"""Returns the string representation of the inner object"""
return self.value.__str__()
def __repr__(self):
"""Returns the string representation of the inner object"""
return self.value.__repr__()
def __call__(self, *args):
"""Simplifies the usage of the wrapper
e.g. a = GlobalData(3)
a() -> a.get_value()
a(1) -> a.set_value(1)"""
if len(args) == 0:
return self.get_value()
elif len(args) == 1:
return self.set_value(*args)
else:
raise ValueError(
f'length of input data must be in [0, 1], but got length: {len(args)} --> args: {args}'
)
def __bool__(self):
"""Returns the actual boolean value of the inner object"""
return self.value.__bool__()
# Constants
# Template for a traced module
MODULE_TEMPLATE = """
%(import_block)s
class %(name_block)s(torch.nn.Module):
def __init__(self):
super().__init__()
%(init_block)s
%(forward_block)s
if __name__ == "__main__":
model = %(name_block)s()
%(load_weight_block)s
model.eval()
model.cpu()
%(input_block)s
output = model(%(input_names)s)
print(output)
"""
# Special math operators
SPECIAL_OPERATORS = ['add', 'and', 'div', 'floordiv', 'lshift', 'mul', 'or', 'pow', 'rshift', 'sub', 'xor', 'truediv']
# Global objects
# Logger
log = get_logger(__name__, 'WARNING')
# Loaded overridable items from the config file
overridable_funcs = {}
overridable_modules = []
overridable_creation_funcs = {}
torch_overrides_funcs = []
torch_overrides_wrappers = []
torch_tracking_modules = []
# Reuse generated wrapper functions and modules
generated_wrapper_funcs = {}
generated_wrapper_modules = {}
# Load state for the override items
overridable_funcs_loaded = GlobalData(False)
overridable_modules_loaded = GlobalData(False)
overridable_creation_funcs_loaded = GlobalData(False)
torch_overrides_funcs_loaded = GlobalData(False)
tracking_modules_loaded = GlobalData(False)
funcs_overrided = GlobalData(False)
modules_overrided = GlobalData(False)
creation_funcs_overrided = GlobalData(False)
tracking_modules_overrided = GlobalData(False)
# Lock for tracing
lock = GlobalData(False)
handle_func_lock = GlobalData(False)
# IDs of the modules whose constructors have been traced
module_constructor_traced = set()
# Current traced graph
current_graph = GlobalData(None)
# Generated module constructor lines
module_constructor_lines = {}
module_constructor_weakrefs = {}
# Directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Original module constructor signatures
module_constructor_signatures = {}
# Original values of tracked objects
original_values_for_tracked_objects = {}
# Original module class names
importable_module_names = {}
# Ignore warnings for updating module parameters
mod_param_update_warning_ignore = GlobalData(False)
# Modules that are skipped while tracing
skip_modules = set()
class TraceNode(object):
"""A basic data structure to represent a node in the computation graph"""
module: typing.Union[torch.nn.Module, 'TraceFunction', 'ConstantNode']
prev_nodes: typing.List['TraceNode']
next_nodes: typing.List['TraceNode']
prev_tensors: typing.List[torch.Tensor]
next_tensors: typing.List[torch.Tensor]
prev_indices: typing.List[typing.Optional[int]]
rev_index: bool
unique_name: str
active: bool
def __init__(
self,
module: typing.Union[torch.nn.Module, 'ConstantNode', 'TraceFunction'],
cur_graph: typing.Optional['TraceGraph'] = None,
):
# Inner module, could either be a `nn.Module`, `ConstantNode` or `TraceFunction`
self.module = module
# Previous and next nodes in the computation graph
self.prev_nodes = []
self.next_nodes = []
# The input and output tensors for the node
self.prev_tensors = []
self.next_tensors = []
# The indices used to retrieve the corresponding tensor
# e.g. torch.chunk() returns [A, B], in which A and B are PyTorch tensors.
# so if we use A in this node, then the corresponding prev_index is 0.
# If the tensor is not a sub item, then `None` should be used.
self.prev_indices = []
# In some nodes, the indices are reversed. For example, for an output node,
# the indices are not used to fetch the items, but to construct a list that
# contains them.
self.rev_index = False
# The current TraceGraph to be processed
# In the trace phase, it can be obtained through `current_graph()`
# Otherwise, you need to pass it explicitly
if cur_graph is None:
cur_graph = current_graph()
# Unique name of the node (the key of the node in the node map in TraceGraph)
if type(module) in (ConstantNode, TraceFunction):
self.unique_name = module.unique_name
else:
self.unique_name = cur_graph.module_unique_name_dict[id(module)]
if isinstance(module, nn.Module) and id(module) in cur_graph.module_original_name_dict:
self.original_name = cur_graph.module_original_name_dict[id(module)]
elif type(module) is ConstantNode:
self.original_name = module.original_name
else:
self.original_name = self.unique_name
# Whether the node is active in the computation graph
self.active = True
# The index of the node in the graph
self.forward_order = 0
# Whether the node is in a quantized graph
self.quantized = False
# Numbering of the name of the node
if cur_graph.global_nodes.get(self.unique_name) is not None:
cur_graph.global_nodes[self.unique_name] += 1
self.unique_name = "_".join([self.unique_name, str(cur_graph.global_nodes[self.unique_name])])
else:
cur_graph.global_nodes[self.unique_name] = 0
def type(self):
"""Returns the original name of the function or the type of the module"""
if type(self.module) is TraceFunction:
return self.module.func_type
return type(self.module)
def kind(self):
"""Returns the kind of the function or the type of the module"""
if type(self.module) is TraceFunction:
return self.module.kind
return type(self.module)
def is_class(self) -> bool:
"""Judges whether it is a class function or not"""
if type(self.module) is TraceFunction:
return self.module.is_class
else:
return False
def full_name(self) -> str:
"""Returns the original full name of the function (including namespace)"""
if type(self.module) in (TraceFunction, ConstantNode):
return self.module.full_name
else:
return f'{type(self.module).__module__}.{type(self.module).__name__}'
def __hash__(self) -> int:
"""Uses the unique name as the hash for the node"""
return hash(self.unique_name)
def prev_node_unique_name(self, idx, inplace=False) -> str:
"""A utility function to generate the name of the previous node with index"""
if idx < len(self.prev_nodes) and idx < len(self.prev_indices):
getattr_on_module = False
if (
isinstance(self.prev_nodes[idx].module, torch.nn.Module)
and type(self.module) is TraceFunction
and self.module.is_property
and '.' not in self.module.full_name
):
getattr_on_module = True
actual_inplace = False
if inplace:
actual_inplace = getattr_on_module
if isinstance(self.prev_nodes[idx].module, ConstantNode):
actual_inplace = True
if actual_inplace:
node_name = self.prev_nodes[idx].original_name
else:
node_name = self.prev_nodes[idx].unique_name
node_idx = self.prev_indices[idx]
ns = ''
if type(self.prev_nodes[idx].module) in (ConstantNode, torch.nn.quantized.FloatFunctional):
ns = 'self.'
elif getattr_on_module:
prev_t_ids = set(id(t) for t in self.prev_tensors)
next_t_ids = set(id(t) for t in self.prev_nodes[idx].next_tensors)
if len(prev_t_ids & next_t_ids) == 0:
ns = 'self.'
if node_idx is None:
return f'{ns}{node_name}'
else:
if isinstance(node_idx, (list, tuple)):
indices_str = ''.join([f'[{i}]' for i in node_idx])
return f'{ns}{node_name}{indices_str}'
else:
return f'{ns}{node_name}[{node_idx}]'
else:
return ''
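# A hedged illustration of how `prev_indices` interacts with
# `prev_node_unique_name`: suppose this node consumes the second output of a
# `torch.chunk` node whose unique name is `chunk_0_f`. Then the corresponding
# entry in `prev_indices` is `1`, and `prev_node_unique_name(idx)` renders the
# access expression used in the generated code:
#
#   node.prev_indices[idx]            # 1
#   node.prev_node_unique_name(idx)   # 'chunk_0_f[1]'
#
# When `prev_indices[idx] is None` (the tensor is not a sub-item), the same
# call returns just 'chunk_0_f'.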
class ConstantNode(object):
"""A data structure for runtime-defined constants"""
def __init__(
self,
data: typing.List,
dtype: torch.dtype,
shape: torch.Size,
unique_name: typing.Optional[str] = None,
original_name: typing.Optional[str] = None,
):
# Raw data (list)
self.data = data
# Data shape
self.shape = tuple(shape)
# Data type
self.dtype = str(dtype)
# Please refer to the description of those properties in `TraceFunction`
self.kind = 'tensor'
self.func_type = 'tensor'
self.full_name = 'torch.tensor'
self.is_class = False
self.is_parameter = False
self.is_persistent = False
self.requires_grad = False
self.data_str = None
# Numbering of the name of the node
if current_graph().global_functions.get(self.kind, None) is None:
current_graph().global_functions[self.kind] = 0
else:
current_graph().global_functions[self.kind] += 1
if unique_name is None:
self.unique_name = "_".join([self.kind, str(current_graph().global_functions[self.kind])])
else:
self.unique_name = unique_name
self.inplace = original_name is not None
if original_name is None:
self.original_name = self.unique_name
else:
self.original_name = original_name
def parse(self, convert_to_parameter: bool = False, persistent: bool = False, requires_grad: bool = False):
def _stringify_list(content) -> str:
"""Convert a list of objects to a string"""
if isinstance(content, (list, tuple)):
sub_contents = []
for item in content:
sub_contents.append(_stringify_list(item))
inner_content = ', '.join(sub_contents)
return f'[{inner_content}]'
elif type(content) in (int, float, bool):
return str(content)
elif type(content) is str:
return f'"{content}"'
# If `convert_to_parameter` is `True`, the content of the data will not be written inline.
self.is_parameter = convert_to_parameter
self.is_persistent = persistent
self.requires_grad = requires_grad
if not persistent:
self.data_str = f'{_stringify_list(self.data)}'
return self
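# A minimal sketch of `ConstantNode.parse()` (it assumes an active trace,
# since the constructor numbers itself through `current_graph()`): small
# tensors are embedded inline through `_stringify_list`, while
# `persistent=True` leaves `data_str` as None so the code generator can
# register the value in the weight file instead.
#
#   c = ConstantNode([[1.0, 2.0], [3.0, 4.0]], torch.float32, torch.Size([2, 2]))
#   c.parse().data_str   # '[[1.0, 2.0], [3.0, 4.0]]'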
class TraceFunction(object):
"""A data structure for traced functions"""
def __init__(
self, full_name: str, is_class: bool = False, is_property: bool = False, prefix: typing.Optional[str] = None
):
super().__init__()
# The base name of the function
self.func_type = full_name.split('.')[-1]
# The kind of the function
# It can be acquired by removing the underscores in the base name of the function for special math
# operators and in-place functions
self.kind = None
if self.func_type.endswith('__') and self.func_type.startswith('__'):
inner_name = self.func_type[2:-2]
if len(inner_name) > 1 and inner_name[0] in ('i', 'r'):
inner_op = inner_name[1:]
if inner_op in SPECIAL_OPERATORS:
self.kind = inner_op
if self.kind is None:
self.kind = inner_name
if self.kind is None:
if self.func_type.endswith('_'):
self.kind = self.func_type[:-1]
else:
self.kind = self.func_type
# Numbering of the nodes
if current_graph().global_functions.get(self.kind, None) is None:
current_graph().global_functions[self.kind] = 0
else:
current_graph().global_functions[self.kind] += 1
if prefix is None:
prefix = ""
# Unique name
self.unique_name = prefix + "_".join([self.kind, str(current_graph().global_functions[self.kind])]) + "_f"
# The input tensors of the function
self.prev_tensors = []
# The name of the function (including namespace)
self.full_name = full_name
# Whether it is a class function/property
self.is_class = is_class
# Whether it is a property
self.is_property = is_property
# Alias
self.aliases = None
# Arguments
self.args = None
self.kwargs = None
self.args_string = None
self.args_parsed = None
self.args_parsed_origin = None
self.tensor_names = None
self.original_tensor_names = None
self.args_template = None
self.args_template_no_self = None
self.args_offset = None
def __repr__(self) -> str:
"""Returns the string representation of the object"""
arg_len = len(self.tensor_names)
if arg_len > 0:
prefix = 'lambda args: '
else:
prefix = 'lambda: '
extra_expr = self.extra_expr('args')
expr = f'{prefix}{extra_expr}'
return expr
def extra_expr(self, prefix=None, first=None, original=False):
"""Returns the extra string representation of the object"""
arg_len = len(self.tensor_names)
if arg_len == 0:
expr = f'{self.full_name}({self.args_template})'
else:
if prefix is None:
if original:
tensor_names = [
(o_name if o_name.startswith('self.') else f'self.{o_name}')
if u_name.startswith('self.')
else u_name
for u_name, o_name in zip(self.tensor_names, self.original_tensor_names)
]
else:
tensor_names = self.tensor_names
else:
tensor_names = [f'{prefix}[{i}]' for i in range(arg_len)]
if first is not None:
tensor_names = [first] + tensor_names[1:]
if self.is_class:
if self.is_property:
expr = f'{tensor_names[0]}.{self.func_type}'
elif self.func_type == '__getitem__':
args = self.args_string_no_self
if not args.startswith('[') and not args.endswith(']'):
args = f'[{args}]'
expr = f'{tensor_names[0]}{args}'
else:
full_template = f'{{}}.{self.func_type}({self.args_template_no_self})'
expr = full_template.format(*tensor_names)
else:
full_template = f'{self.full_name}({self.args_template})'
expr = full_template.format(*tensor_names)
return expr
def __call__(self, *args, **kwargs):
"""Calls the function with a list of tensor inputs"""
if len(kwargs) > 0:
log.warning('Keyword arguments are ignored when calling TraceFunction')
if len(args) == 0:
arg_len = 0
else:
if len(args) > 1:
log.warning('Multiple arguments are passed in, but all but the first one will be ignored')
if not isinstance(args[0], (tuple, list)):
log.error('Only tuple or list is accepted here')
assert False
arg_len = len(args[0])
expected_len = len(self.tensor_names)
if arg_len != expected_len:
log.error(f'Wrong number of input tensors, expected: {expected_len}, but got {arg_len}')
assert False
expr = self.extra_expr('args[0]')
return eval(expr)
def parse_args(self, *args, **kwargs):
"""Sets the string representation of the arguments"""
def _tensor_name(a, convert_to_parameter=False, original=False):
"""Get the tensor name from the computation graph"""
ns = ''
if constant_handler(a, self.unique_name, self.full_name):
ns = 'self.'
pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
if original:
node = current_graph().nodes_map[pre_node_name]
pre_node_name = node.original_name
else:
pre_node_name = current_graph().tensor_pre_node_dict[id(a)]
node = current_graph().nodes_map[pre_node_name]
if original:
pre_node_name = node.original_name
if type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional):
ns = 'self.'
if id(a) in current_graph().tensor_pre_index_dict:
pre_node_index = current_graph().tensor_pre_index_dict[id(a)]
log.debug(f'pre_index gen func {self.kind}: {pre_node_index}')
if isinstance(pre_node_index, (list, tuple)):
indices_str = ''.join([f'[{i}]' for i in pre_node_index])
return f"{ns}{pre_node_name}{indices_str}"
else:
return f"{ns}{pre_node_name}[{pre_node_index}]"
else:
return f"{ns}{pre_node_name}"
def _escape_arg(arg: str):
"""Escapes the special characters in the argument string"""
for c in ('{', '}'):
if c in arg:
arg = arg.replace(c, f'{c}{c}')
return arg
def _parse_args(arg):
"""Converts the argument to a list of strings"""
new_arg = []
for a in arg:
if isinstance(a, (list, tuple, torch.Size)):
new_arg.append(_parse_args(a))
elif type(a) in (torch.Tensor, torch.nn.Parameter) or (
type(a) in (torch.dtype, torch.device, torch.Size) and id(a) in current_graph().tensor_pre_node_dict
):
self.prev_tensors.append(a)
self.tensor_names.append(_tensor_name(a))
self.original_tensor_names.append(_tensor_name(a, original=True))
new_arg.append('{}')
elif type(a) in (str, torch.device):
new_arg.append(_escape_arg(f"\'{a}\'"))
elif type(a) in (int, bool, torch.dtype):
new_arg.append(str(a))
elif type(a) is float:
str_arg = str(a)
if str_arg in ('nan', 'inf', '-inf'):
new_arg.append(f"float('{str_arg}')")
else:
new_arg.append(str_arg)
elif a is None:
new_arg.append('None')
elif a is Ellipsis:
new_arg.append('...')
elif type(a) is slice:
t = (a.start, a.stop, a.step)
parts = []
for x in t:
if x is None:
parts.append('')
else:
parts.extend(_parse_args([x]))
r = ':'.join(parts)
if r.endswith(':'):
r = r[:-1]
new_arg.append(r)
elif isinstance(a, torch.nn.quantized.FloatFunctional):
float_functional_cls = type(a)
module_constructor_lines[id(a)] = f'{qualified_name(float_functional_cls, short=True)}()'
new_node = TraceNode(a)
current_graph().nodes_map[new_node.unique_name] = new_node
current_graph().other_init_nodes.append(new_node)
current_graph().tensor_pre_node_dict[id(a)] = new_node.unique_name
self.tensor_names.append(_tensor_name(a))
self.original_tensor_names.append(_tensor_name(a, original=True))
self.prev_tensors.append(a)
new_arg.append('{}')
elif isinstance(a, nn.Module):
unique_name = current_graph().module_unique_name_dict[id(a)]
current_graph().tensor_pre_node_dict[id(a)] = unique_name
self.tensor_names.append(f'self.{unique_name}')
self.original_tensor_names.append(_tensor_name(a, original=True))
self.prev_tensors.append(a)
new_arg.append('{}')
else:
log.error(f"unsupported type {type(a)} while generating arg for func {self.full_name}")
assert False
return new_arg
self.tensor_names = []
self.original_tensor_names = []
self.prev_tensors.clear()
arg_str = _parse_args(args)
kw_items = kwargs.items()
if kw_items:
kw_keys, kw_vals = zip(*kw_items)
kw_val_strs = _parse_args(kw_vals)
for (k, v) in zip(kw_keys, kw_val_strs):
if type(v) is list:
v_str = self._flatten_list(v)
arg_str.append(f"{k}={v_str}")
else:
arg_str.append(f"{k}={v}")
self.args_parsed = copy.deepcopy(arg_str)
self.kwargs = copy.deepcopy(kwargs)
self.args_parsed_origin = copy.deepcopy(self.args_parsed)
self.args_to_string(self.args_parsed)
return self
def _flatten_list(self, content):
"""Flatten a list of nested list or string into a string"""
if isinstance(content, list):
sub_contents = []
for item in content:
sub_contents.append(self._flatten_list(item))
inner_content = ', '.join(sub_contents)
return f'[{inner_content}]'
else:
return content
def args_to_string(self, arg_str):
for i in range(len(arg_str)):
if type(arg_str[i]) is list:
arg_str[i] = self._flatten_list(arg_str[i])
try:
self.args_template = ", ".join(arg_str)
self.args_template_no_self = ", ".join(arg_str[1:])
self.args_offset = 1 if arg_str[0] == '{}' else 0
self.update_args_string()
except Exception:
log.error(f"Error generating argument string for function {self.full_name}")
assert False
return self
def update_tensor_name(self, index, new_name):
"""Updates the tensor name at the given index"""
self.tensor_names[index] = new_name
def replace_tensor_name(self, old_name, new_name):
"""Replaces the specific tensor name with the given one"""
for i, name in enumerate(self.tensor_names):
if name == old_name:
self.tensor_names[i] = new_name
def update_args_string(self):
"""Updates the string representation according to the templates and the tensor names"""
if self.args_template:
self.args_string = self.args_template.format(*self.tensor_names)
if self.args_template_no_self:
self.args_string_no_self = self.args_template_no_self.format(*self.tensor_names[self.args_offset :])
def get_tensor_name(self, index, original):
"""Retrieves the tensor name at the given index"""
if original:
return self.original_tensor_names[index]
else:
return self.tensor_names[index]
def add_alias(self, name, head=False):
"""Adds an aliases of the tensor"""
if self.aliases is None:
self.aliases = []
if head:
self.aliases.insert(0, name)
else:
self.aliases.append(name)
def get_aliases(self):
"""Retrieves the aliases of the tensor"""
return self.aliases
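# A hedged sketch of how TraceFunction turns a call into a code template. The
# snippet assumes an active trace in which the tensor `x` is already tracked;
# the rendered tensor name depends on the unique name of the node that
# produced `x` (e.g. 'input_0_f' for an input node):
#
#   tf = TraceFunction('torch.Tensor.permute', is_class=True).parse_args(x, 0, 2, 1)
#   tf.args_template           # '{}, 0, 2, 1'
#   tf.args_template_no_self   # '0, 2, 1'
#   tf.extra_expr()            # e.g. 'input_0_f.permute(0, 2, 1)'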
@contextlib.contextmanager
def no_catch():
"""Context manager for tracing nodes. Use it to avoid tracing the nodes recursively."""
if lock():
yield False
else:
lock(True)
yield True
lock(False)
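# Behaviour of the tracing lock, in brief: the outermost `with no_catch()`
# acquires the lock and yields True; any nested use while the lock is held
# yields False, so wrapper code that fires re-entrantly becomes a no-op.
#
#   with no_catch() as outer:      # outer is True, the lock is now held
#       with no_catch() as inner:  # inner is False, nothing to undo
#           ...
#   # the lock is released again here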
@contextlib.contextmanager
def no_catch_handle_func():
"""Context manager for tracing nodes. Use it to avoid hacking into handle_torch_function recursively."""
if handle_func_lock():
yield False
else:
handle_func_lock(True)
yield True
handle_func_lock(False)
def args_as_string(args, kwargs):
"""String representation of the args and the keyword args"""
cleaned_args = [f'"{arg}"' if type(arg) is str else str(arg) for arg in args]
args_content = ', '.join(cleaned_args)
kwargs_content = ', '.join((f'{k}="{v}"' if type(v) is str else f'{k}={v}' for k, v in kwargs.items()))
args_connector = '' if args_content == '' or kwargs_content == '' else ', '
full_args_content = f'{args_content}{args_connector}{kwargs_content}'
return full_args_content
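# `args_as_string` is a pure helper that renders arguments as they would
# appear in source code, for example:
#
#   args_as_string((1, 'same'), {'bias': False})   # '1, "same", bias=False'
#   args_as_string((), {})                         # ''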
def new_setattr_gen(orig_setattr, key: str):
"""Wrapper function for the __setattr__ functions of the modules in PyTorch"""
log.debug(f'registered module setattr wrapper: {key}')
def new_setattr(obj, name, value):
log.debug(f'{key} in setattr function wrapper')
related = True
log.debug(f'{key} before with block, lock: {lock}')
with no_catch() as res:
if res:
if id(obj) not in module_constructor_traced:
related = False
class_type = type(obj)
if related and not hasattr(class_type, '__constants__'):
related = False
if related and name not in class_type.__constants__:
related = False
if related:
class_name = '.'.join(key.split('.')[:-1])
log_func = log.warning
if mod_param_update_warning_ignore():
log_func = log.debug
log_func(
f'The constant property `{name}` of {class_name} is changed. We need to drop the original'
' constructor line.'
)
module_constructor_traced.remove(id(obj))
del module_constructor_lines[id(obj)]
return orig_setattr(obj, name, value)
log.debug(f'{key} after with block, lock: {lock}')
return new_setattr
def new_module_getattr_gen(orig_getattr, key: str, res: typing.Dict[str, bool]):
log.debug(f'registered module getattr wrapper: {key}')
def new_getattr(obj, name):
log.debug(f'{key} in module getattr function wrapper')
result = orig_getattr(obj, name)
result_is_str = isinstance(result, str)
is_dict = name == '__dict__'
related = result_is_str or is_dict
# If the following conditions are satisfied
# a. the type of the result is `str` or the key is `__dict__`
# b. the 2nd last frame is on the `extra_repr` function or `__repr__` function
# c. the line is not starting with `if `
# then we should patch the variable.
if related:
last_frame = inspect.currentframe().f_back
func_name = last_frame.f_code.co_name
related = False
if func_name in ('extra_repr', '__repr__'):
if is_dict:
related = True
else:
fn = last_frame.f_code.co_filename
ln = last_frame.f_lineno
res_key = f'{fn}_{ln}'
if res_key in res:
related = res[res_key]
else:
lines = inspect.getframeinfo(last_frame)[3]
related = not re.match(r' *if .*', lines[0]) and not re.match(r'.*repr\(self\..*\)', lines[0])
res[res_key] = related
if related:
if result_is_str:
return f'"{result}"'
elif is_dict:
orig = result
result = {}
for k, v in orig.items():
if type(v) is str and (not k.startswith('__') and not k.endswith('__')):
result[k] = f'"{orig[k]}"'
else:
result[k] = v
return result
return new_getattr
def constant_handler(
tensor: torch.Tensor, node_name: typing.Optional[str] = None, type_name: typing.Optional[str] = None
) -> bool:
"""Appends the constant tensor to the computation graph
Args:
tensor (torch.Tensor): The constant tensor
node_name (typing.Optional[str], optional): The name of the node that depends on that tensor. Defaults to None.
type_name (typing.Optional[str], optional): The type of the node that depends on that tensor. Defaults to None.
Returns:
bool: Whether it is a new constant tensor
"""
if id(tensor) not in current_graph().tensor_pre_node_dict:
if id(tensor) in current_graph().parameter_module_dict:
mod_id = current_graph().parameter_module_dict[id(tensor)]
unique_name = current_graph().module_unique_name_dict[mod_id]
if unique_name in current_graph().related_modules:
mod = ctypes.cast(mod_id, ctypes.py_object).value
node = TraceNode(mod)
add_forward_node(node, [], [])
current_graph().tensor_pre_node_dict[mod_id] = node.unique_name
key = current_graph().parameter_original_name_dict[id(tensor)].split('.')[-1]
trace_func = TraceFunction(key, True, True).parse_args(mod)
trace_node = TraceNode(trace_func)
add_forward_node(trace_node, [mod], tensor)
return False
if not tensor.is_leaf:
if node_name is None:
node_name = 'unknown node'
if type_name is None:
type_name = 'unknown'
log.error(f'Connection is lost when generating code for {node_name} of type {type_name}')
else:
# constant tensor generation
log.warning('Constant generation is experimental and may yield errors')
convert_to_parameter = False
persistent = False
requires_grad = tensor.requires_grad
if isinstance(tensor, torch.nn.Parameter):
convert_to_parameter = True
if tensor.numel() > 50:
persistent = True
raw_data = tensor.tolist()
unique_name = current_graph().parameter_unique_name_dict.get(id(tensor), None)
original_name = current_graph().parameter_original_name_dict.get(id(tensor), None)
with no_catch():
constant_node = ConstantNode(raw_data, tensor.dtype, tensor.shape, unique_name, original_name).parse(
convert_to_parameter, persistent, requires_grad
)
trace_node = TraceNode(constant_node)
add_constant_node(trace_node, tensor)
return True
else:
return False
def new_getattr_gen(orig_getattr, key: str, is_class: bool):
"""Wrapper function for the __getattribute__ functions of the modules in PyTorch"""
log.debug(f'registered module getattr wrapper: {key}')
def new_getattr(obj, name):
log.debug(f'{key} in getattr function wrapper')
related = False
log.debug(f'{key} before with block, lock: {lock}')
with no_catch() as res:
result = orig_getattr(obj, name)
if current_graph() is None:
related = False
if name in ('device', 'shape', 'data', 'dtype'):
related = True
if res:
if related:
# Also the property should be constant if the result object is unchanged.
# Only create a new node when there isn't one.
if (
id(result) in current_graph().tensor_pre_node_dict
and id(result) in original_values_for_tracked_objects
and original_values_for_tracked_objects[id(result)] == result
):
node_name = current_graph().tensor_pre_node_dict[id(result)]
trace_node = current_graph().nodes_map[node_name]
if trace_node.module.is_property and trace_node.module.func_type == 'shape':
result = tuple(trace_node.next_tensors)
else:
# Handling dynamic shape
# If the torch.Size object is generated by a tensor,
# then we connect it to the graph.
# Otherwise, don't track it.
old_result = None
if type(result) is torch.Size and isinstance(obj, torch.Tensor):
# Create a list of new tensors for the sake of tracking
# The reason to use that instead of a tensor is stated below.
# e.g. Users may use the following clause to deal with sizes
# x, y = tensor.size()
# Currently, there is no way to trace it.
# However, by doing this, if user calls `numel` on the `torch.Size`
# object, it will now throw an exception.
# TODO: Fix the case if user calls `numel` on `torch.Size`
constant_handler(obj, type_name=key)
original_values_for_tracked_objects[id(result)] = copy.deepcopy(result)
new_result = []
for elem in result:
new_result.append(torch.tensor(elem))
current_graph().tensor_pre_node_dict[
id(new_result[-1])
] = current_graph().tensor_pre_node_dict[id(obj)]
old_result = result
result = tuple(new_result)
log.debug(f'{key} is called with {name}')
new_key = key.replace('__getattribute__', name)
trace_func = TraceFunction(new_key, True, True).parse_args(obj)
trace_node = TraceNode(trace_func)
if old_result is not None:
current_graph().tensor_pre_node_dict[id(old_result)] = trace_node.unique_name
add_forward_node(trace_node, trace_func.prev_tensors, result)
log.debug(f'{key} after with block, lock: {lock}')
return result
return new_getattr
def new_init_tracking_gen(orig_init, key: str):
"""Wrapper function for the init functions of the modules in PyTorch"""
log.debug(f'registered module init tracking wrapper: {key}')
def new_init_tracking(obj, args, kwargs):
log.debug(f'{key} in init tracking function wrapper')
with no_catch():
is_tensor = torch.is_tensor(args[0])
if is_tensor:
return args[0].unbind(0)
else:
with no_catch():
return tuple([torch.as_tensor(x) for x in args[0]])
return new_init_tracking
def new_init_gen(orig_init, key: str):
"""Wrapper function for the init functions of the modules in PyTorch"""
log.debug(f'registered module init wrapper: {key}')
def new_init(obj, *args, **kwargs):
log.debug(f'{key} in init function wrapper')
module_constructor_traced.add(id(obj))
init_fullname = key
class_fullname = '.'.join(init_fullname.split('.')[:-1])
log.debug(f'{key} before with block, lock: {lock}')
with no_catch() as res:
if id(obj) not in module_constructor_lines or (
id(obj) in module_constructor_weakrefs and module_constructor_weakrefs[id(obj)]() is None
):
if not res:
log.warning(
f'Failed to acquire the tracing lock while tracing {init_fullname}, which is unexpected.'
)
log.debug(f'{key} in with block, lock: {lock}')
actual_class_name = qualified_name(type(obj))
if actual_class_name == class_fullname:
rckwa = check_types(kwargs.values())
rca = check_types(args)
err_type = rca or rckwa
if err_type:
log.warning(
f'Constructor of {class_fullname} has arguments of type {err_type} which is unsupported'
)
log.warning(f' Args: {args}')
log.warning(f' Keyword args: {kwargs}')
else:
log.info(f'Constructor of {class_fullname} registered')
full_args_content = args_as_string(args, kwargs)
orig_constructor_line = f'{class_fullname}({full_args_content})'
module_constructor_lines[id(obj)] = orig_constructor_line
module_constructor_weakrefs[id(obj)] = weakref.ref(obj)
else:
module_constructor_traced.remove(id(obj))
if not actual_class_name.startswith('torch.'):
log.warning(f'Constructor of class {actual_class_name} is not captured')
orig_init(obj, *args, **kwargs)
log.debug(f'{key} after with block, lock: {lock}')
return new_init
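# A hedged example of what the constructor wrapper records (assuming
# torch.nn.Linear is listed in configs/torch_module_override.yml, so its
# __init__ gets patched): constructing the module while the patch is active
# stores a reproducible constructor line keyed by the object's id, roughly
#
#   m = torch.nn.Linear(4, 8, bias=False)
#   module_constructor_lines[id(m)]   # 'torch.nn.Linear(4, 8, bias=False)'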
def new_func_gen(orig_func, key: str, is_class: bool):
"""Wrapper function for functions in PyTorch"""
log.debug(f'registered function wrapper: {key}')
def new_func(*args, **kwargs):
log.debug(f'{key} in function wrapper')
related = False
log.debug(f'{key} before with block, lock: {lock}')
with no_catch() as res:
result = orig_func(*args, **kwargs)
if res and current_graph() is not None:
log.debug(f'{key} in with block, lock: {lock}')
if key == 'torch.Tensor.size' and len(args) > 1:
# Tracking torch.Tensor.size with optional int argument
result = torch.tensor(result)
if type(result) is torch.Size:
# Handling dynamic shape
# If the torch.Size object is generated by a tensor,
# then we connect it to the graph.
# Otherwise, don't track it.
if len(args) > 0 and type(args[0]) is torch.Tensor:
# Create a list of new tensors for the sake of tracking
# The reason to use that instead of a tensor is stated below.
# e.g. Users may use the following clause to deal with sizes
# x, y = tensor.size()
# Currently, there is no way to trace it.
# However, by doing this, if user calls `numel` on the `torch.Size`
# object, it will now throw an exception.
# TODO: Fix the case if user calls `numel` on `torch.Size`
new_result = []
for elem in result:
new_result.append(torch.tensor(elem))
current_graph().tensor_pre_node_dict[
id(new_result[-1])
] = current_graph().tensor_pre_node_dict[id(args[0])]
result = tuple(new_result)
related = True
elif type(result) in (torch.dtype, torch.device):
related = True
else:
related = check_tensor_type(result)
log.debug(f'{key} after with block, lock: {lock}')
if related:
log.debug(f'tracing {key} in function wrapper')
with no_catch() as res:
if res:
trace_func = TraceFunction(key, is_class).parse_args(*args, **kwargs)
trace_node = TraceNode(trace_func)
modified_result = noop_handler(trace_node, trace_func.prev_tensors, result)
if modified_result is not None:
result = modified_result
add_forward_node(trace_node, trace_func.prev_tensors, result)
log.debug(f'tracing {key} function wrapper complete')
return result
return new_func
def new_has_torch_func_gen(orig_func, key: str, is_class: bool):
"""Wrapper function for functions in PyTorch"""
log.debug(f'registered has torch func wrapper: {key}')
def new_func(*args, **kwargs):
with no_catch_handle_func() as res:
return (res and not lock()) or orig_func(*args, **kwargs)
return new_func
def new_handle_func_gen(orig_func, key: str, is_class: bool):
"""Wrapper function for functions in PyTorch"""
log.debug(f'registered handle torch func wrapper: {key}')
def new_func(func, tracked_args, *args, **kwargs):
if lock():
return orig_func(func, tracked_args, *args, **kwargs)
else:
with no_catch_handle_func():
return func(*args, **kwargs)
return new_func
def new_creation_func_gen(orig_func, key: str, is_class: bool):
"""Wrapper function for functions in PyTorch"""
log.debug(f'registered creation function wrapper: {key}')
def new_func(*args, **kwargs):
log.debug(f'{key} in creation function wrapper')
with no_catch() as res:
result = orig_func(*args, **kwargs)
log.debug(f'tracing {key} in creation function wrapper')
with no_catch() as res:
if res:
trace_func = TraceFunction(key, is_class).parse_args(*args, **kwargs)
trace_node = TraceNode(trace_func)
add_forward_node(trace_node, trace_func.prev_tensors, result)
return result
return new_func
def fetch_modules(config: typing.Optional[str] = None):
"""Fetches the functions from the config."""
if config is None:
config = os.path.join(current_dir, 'configs/torch_module_override.yml')
modules = []
with open(config, 'r') as f:
module_dict = yaml.load(f, yaml.SafeLoader)
for ns, module_names in module_dict.items():
try:
scope = importlib.import_module(ns)
except ImportError:
continue  # skip namespaces that cannot be imported, instead of falling through with an undefined scope
for module_name in module_names:
if hasattr(scope, module_name):
module = getattr(scope, module_name)
modules.append(module)
importable_module_names[module] = f'{ns}.{module_name}'
if hasattr(module, '__init__'):
constructor = module.__init__
module_constructor_signatures[module] = inspect.signature(constructor).parameters.values()
return modules
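# The YAML consumed by `fetch_modules` maps a namespace to the class names to
# override. An illustrative (not verbatim) excerpt of the expected shape:
#
#   torch.nn:
#     - Conv2d
#     - Linear
#   torch.nn.quantized:
#     - FloatFunctional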
def fetch_funcs(config: typing.Optional[str] = None):
"""Fetches the functions from the config."""
if config is None:
version_parts = torch.__version__.split('.')
if int(version_parts[0]) == 1:
if int(version_parts[1]) < 6:
version_parts[1] = '6'
if int(version_parts[1]) > 12:
version_parts[1] = '12'
elif int(version_parts[0]) == 2:
if int(version_parts[1]) > 0:
version_parts[1] = '0'
else:
log.warning(f'Your PyTorch version is unsupported: {torch.__version__}')
version_parts = ['1', '6']
version_str = '_'.join(version_parts[:2])
config = os.path.join(current_dir, f'configs/torch_func_override_{version_str}.yml')
modules = []
with open(config, 'r') as f:
module_dict = yaml.load(f, yaml.SafeLoader)
new_dict = {}
for ns, module_names in module_dict.items():
log.debug(f'Attempting to load {ns}')
try:
spec = importlib.util.find_spec(ns)
except ImportError:
continue
if spec is None:
modules = ns.split('.')
ns = '.'.join(modules[:-1])
typename = modules[-1]
spec = importlib.util.find_spec(ns)
if spec is None:
log.warning(f"Error importing {ns}, which may not be a module")
continue
scope = importlib.import_module(ns)
if hasattr(scope, typename):
scope = getattr(scope, typename)
else:
log.warning(f"Error importing {ns}.{typename}")
continue
else:
scope = importlib.import_module(ns)
modules = []
for module_name in module_names:
if hasattr(scope, module_name):
modules.append(module_name)
importable_module_names[getattr(scope, module_name)] = f'{ns}.{module_name}'
new_dict[scope] = modules
return new_dict
def prepare_torch_overrides_funcs(funcs):
tracked_funcs = []
wrappers = []
if hasattr(torch, 'overrides') and inspect.ismodule(torch.overrides):
all_has_torch_func_names = ['has_torch_function', 'has_torch_function_unary', 'has_torch_function_variadic']
all_handle_func_names = ['handle_torch_function']
has_torch_func_names = []
for n in all_has_torch_func_names:
if hasattr(torch.overrides, n):
has_torch_func_names.append(n)
handle_func_names = []
for n in all_handle_func_names:
if hasattr(torch.overrides, n):
handle_func_names.append(n)
has_torch_funcs = {torch.overrides: has_torch_func_names}
handle_funcs = {torch.overrides: handle_func_names}
for ns in funcs.keys():
if ns == torch.Tensor:
if hasattr(torch, '_tensor') and inspect.ismodule(torch._tensor):
ns = torch._tensor
else:
ns = sys.modules['torch.tensor']
ns_has_torch_func_names = []
for k in has_torch_func_names:
if hasattr(ns, k):
ns_has_torch_func_names.append(k)
ns_handle_func_names = []
for k in handle_func_names:
if hasattr(ns, k):
ns_handle_func_names.append(k)
if len(ns_has_torch_func_names) > 0:
has_torch_funcs.update({ns: ns_has_torch_func_names})
if len(ns_handle_func_names) > 0:
handle_funcs.update({ns: ns_handle_func_names})
tracked_funcs.extend((has_torch_funcs, handle_funcs))
wrappers.extend((new_has_torch_func_gen, new_handle_func_gen))
return tracked_funcs, wrappers
def qualified_name(module, item: typing.Optional[str] = None, short: bool = False):
if module in importable_module_names:
obj_key = importable_module_names[module]
elif hasattr(module, '__module__'):
obj_key = f'{module.__module__}.{module.__name__}'
if short:
mod = module.__module__
name = module.__name__
pos = [i for i, x in enumerate(mod) if x == '.']
for i in pos:
ns = mod[:i]
if ns in sys.modules:
cur_mod = sys.modules[ns]
if hasattr(cur_mod, name) and getattr(cur_mod, name) == module:
obj_key = f'{ns}.{name}'
break
elif isinstance(module, str):
obj_key = module
else:
obj_key = module.__name__
if item is not None:
return f'{obj_key}.{item}'
else:
return obj_key
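# Examples of `qualified_name` (assuming Conv2d has not already been
# registered in `importable_module_names`; if it has, the registered name is
# returned directly):
#
#   qualified_name(torch.nn.Conv2d)               # 'torch.nn.modules.conv.Conv2d'
#   qualified_name(torch.nn.Conv2d, short=True)   # 'torch.nn.Conv2d'
#   qualified_name(torch.Tensor, 'size')          # 'torch.Tensor.size'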
@contextlib.contextmanager
def patch(object, name, gen, *args, **kwargs):
"""Temporarily monkeypatches an object."""
pre_patched_value = getattr(object, name)
setattr(object, name, gen(pre_patched_value, *args, **kwargs))
yield object
setattr(object, name, pre_patched_value)
@contextlib.contextmanager
def patch_modules(objects, names, gens):
"""Temporarily monkeypatches the modules in PyTorch."""
if not isinstance(names, (tuple, list)) and not isinstance(gens, (tuple, list)):
names = (names,)
gens = (gens,)
pre_patched_values = {}
for obj in objects:
for name, gen in zip(names, gens):
key = qualified_name(obj, name, short=True)
pre_patched_values[key] = getattr(obj, name)
generated_wrapper_modules.setdefault(key, gen(pre_patched_values[key], key))
if key == 'torch.Size.__new__':
pre_patched_values[key] = patch_new(obj, generated_wrapper_modules[key])
else:
setattr(obj, name, generated_wrapper_modules[key])
yield objects
for obj in objects:
for name in names:
key = qualified_name(obj, name, short=True)
pre_patched_value = pre_patched_values[key]
if key == 'torch.Size.__new__':
revert_new(obj, pre_patched_value)
else:
setattr(obj, name, pre_patched_value)
@contextlib.contextmanager
def patch_funcs(object_dicts, gens):
"""Temporarily monkeypatches the functions in PyTorch."""
if not isinstance(object_dicts, (tuple, list)) and not isinstance(gens, (tuple, list)):
object_dicts = (object_dicts,)
gens = (gens,)
pre_patched_value_dict = {}
for object_dict, gen in zip(object_dicts, gens):
for obj, names in object_dict.items():
for name in names:
key = qualified_name(obj, name)
if key in pre_patched_value_dict:
log.warning(f'{key} declared more than once in torch_func_override.yml, skipping')
else:
if key == 'torch.Tensor.__getitem__':
pre_patched_value_dict[key] = getattr(torch._C._TensorBase, '__getitem__')
generated_wrapper_funcs.setdefault(
key, gen(pre_patched_value_dict[key], key, hasattr(obj, '__module__'))
)
new_func = generated_wrapper_funcs[key]
key = qualified_name(obj, name)
pre_patched_value_dict[key] = patch_getitem(obj, new_func)
else:
pre_patched_value_dict[key] = getattr(obj, name)
generated_wrapper_funcs.setdefault(
key, gen(pre_patched_value_dict[key], key, hasattr(obj, '__module__'))
)
setattr(obj, name, generated_wrapper_funcs[key])
yield object_dict
for object_dict, gen in zip(object_dicts, gens):
for obj, names in object_dict.items():
for name in names:
key = qualified_name(obj, name)
if key == 'torch.Tensor.__getitem__':
key = qualified_name(obj, name)
revert_getitem(obj, pre_patched_value_dict[key])
else:
setattr(obj, name, pre_patched_value_dict[key])
def get_constructor_args(actual_class_type):
"""Gets the args of the original constructor for a known module class"""
if actual_class_type in module_constructor_signatures:
return module_constructor_signatures[actual_class_type]
else:
return inspect.signature(actual_class_type.__init__).parameters.values()
def gen_module_constrctor_line(module, mod_cache=None):
"""Generates the constructor line for a loaded module"""
ignored_args = {
'torch.nn.ZeroPad2d': ['value'],
'torch.nn.UpsamplingNearest2d': ['mode'],
'torch.nn.UpsamplingBilinear2d': ['mode'],
}
legacy_modules = {
'torch.nn.FractionalMaxPool2d',
'torch.nn.FractionalMaxPool3d',
'torch.nn.TransformerEncoderLayer',
'torch.nn.TransformerDecoderLayer',
'torch.nn.Upsample',
}
def _skip_ignored_args(name, *args, **kwargs):
iargs = ignored_args[name]
for arg in iargs:
if isinstance(arg, int):
if arg >= 0 and arg < len(args):
del args[arg]
elif isinstance(arg, str):
if arg in kwargs:
del kwargs[arg]
else:
raise AttributeError(f'Unknown type {type(arg)} in ignored args')
return args_as_string(args, kwargs)
name = qualified_name(type(module), short=True)
if mod_cache is None:
mod_cache = {}
module_cls = type(module)
if name in legacy_modules:
if hasattr(module_cls, '__constants__') and hasattr(module_cls, '__init__'):
known_constants = set(module_cls.__constants__)
arg_info = get_constructor_args(module_cls)
args = []
for p in arg_info:
prop_name = p.name
if prop_name == 'self':
continue
if prop_name not in known_constants:
if not hasattr(module_cls, prop_name) or not isinstance(
getattr(module_cls, prop_name), (type(None), str, int, float, bool)
):
log.warning(
f'Argument "{prop_name}" of the constructor of {name} is not a known constant, skipping'
)
continue
if p.default is p.empty:
prop_value = getattr(module, prop_name)
args.append(f'{prop_value}')
else:
# If loading from legacy model, then the property can be missing.
# Thus, we should skip it so as to use the default value.
if not hasattr(module, prop_name):
continue
prop_value = getattr(module, prop_name)
# Appending keyword args
default_value = p.default
# Skip the arg if it has the same value with the default one
if default_value == prop_value:
continue
if type(prop_value) is str:
prop_value_str = f'"{prop_value}"'
else:
prop_value_str = prop_value
args.append(f'{prop_name}={prop_value_str}')
mid = ', '.join(args)
result = f'{name}({mid})'
else:
with patch(module_cls, '__getattribute__', new_module_getattr_gen, name, mod_cache):
if getattr(module_cls, '__repr__') == getattr(torch.nn.Module, '__repr__'):
result = f'{name}({module.extra_repr()})'
else:
ns = '.'.join(name.split('.')[:-1])
result = f'{ns}.{repr(module)}'
if name in ignored_args:
start_pos = result.find('(')
end_pos = result.rfind(')')
head = result[: start_pos + 1]
tail = result[end_pos:]
mid = result[start_pos + 1 : end_pos]
result = head + eval(f'_skip_ignored_args("{name}", {mid})') + tail
return result, mod_cache
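# A rough example of the output: for a module whose class does not override
# __repr__, the `extra_repr` path above is used, so roughly
#
#   gen_module_constrctor_line(torch.nn.Dropout(0.25))
#   # ('torch.nn.Dropout(p=0.25, inplace=False)', {...})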
def noop_handler(node, inputs, outputs):
"""Generate modified outputs if the inputs and the outputs are the same"""
with no_catch():
is_list = False
is_tuple = False
modified = False
if isinstance(outputs, list):
is_list = True
elif isinstance(outputs, tuple):
is_tuple = True
elif isinstance(outputs, torch.Tensor):
if id(outputs) in current_graph().tensor_pre_node_dict:
return outputs.view(outputs.shape)
if is_list or is_tuple:
is_tracked = [
isinstance(t, torch.Tensor) and id(t) in current_graph().tensor_pre_node_dict for t in outputs
]
modified = any(is_tracked)
if modified:
new_outputs = [t.view(t.shape) if d else t for t, d in zip(outputs, is_tracked)]
if is_tuple:
return tuple(new_outputs)
else:
return new_outputs
else:
return None
def add_input_node(node: TraceNode, output_tensors):
"""Adds an input node to the current computation graph"""
assert node is not None
if not isinstance(output_tensors, (list, tuple)):
output_tensors = [output_tensors]
node.next_tensors.extend(output_tensors)
for t in output_tensors:
current_graph().tensor_pre_node_dict[id(t)] = node.unique_name
current_graph().input_nodes.append(node)
current_graph().nodes_map[node.unique_name] = node
def add_constant_node(node: TraceNode, output_tensor):
"""Adds a constant node to the current computation graph"""
assert node is not None
actual_tensor = output_tensor
if isinstance(output_tensor, torch.nn.Parameter):
actual_tensor = output_tensor.data
current_graph().tensor_parameter_dict[id(output_tensor)] = weakref.ref(actual_tensor)
node.next_tensors = [actual_tensor]
current_graph().tensor_pre_node_dict[id(output_tensor)] = node.unique_name
current_graph().constant_nodes.append(node)
current_graph().nodes_map[node.unique_name] = node
def add_output_node(node: TraceNode, input_tensors):
"""Adds an output node to the current computation graph"""
assert node is not None
need_idx = True
if not isinstance(input_tensors, (list, tuple)):
input_tensors = [input_tensors]
need_idx = False
node.prev_tensors.extend(input_tensors)
node.rev_index = need_idx
for i, t in enumerate(input_tensors):
node.prev_nodes.append(current_graph().nodes_map[current_graph().tensor_pre_node_dict[id(t)]])
if id(t) in current_graph().tensor_pre_index_dict:
node.prev_indices.append(current_graph().tensor_pre_index_dict[id(t)])
else:
node.prev_indices.append(None)
current_graph().output_nodes.append(node)
current_graph().nodes_map[node.unique_name] = node
def add_forward_node(node: TraceNode, input_tensors, output_tensors):
"""Adds a forward node to the current computation graph"""
assert node is not None
if not isinstance(input_tensors, (list, tuple)):
input_tensors = [input_tensors]
need_idx = True
if not isinstance(output_tensors, (list, tuple)):
output_tensors = [output_tensors]
need_idx = False
flatten_inputs = []
for t in input_tensors:
if isinstance(t, (list, tuple)):
for rt in t:
flatten_inputs.append(rt)
else:
flatten_inputs.append(t)
node.prev_tensors.extend(flatten_inputs)
node.next_tensors.extend(output_tensors)
for i, t in enumerate(flatten_inputs):
assert type(t) in (
torch.dtype,
torch.device,
torch.Size,
torch.Tensor,
torch.nn.Parameter,
torch.nn.quantized.FloatFunctional,
) or isinstance(t, torch.nn.Module), (
f'Input #{i} of {node.unique_name}({node.type()}) should be one of the following type '
' [torch.dtype, torch.device, torch.Size, torch.Tensor,'
f' torch.nn.Parameter,torch.nn.quantized.FloatFunctional, torch.nn.Module], but got {type(t)}'
)
constant_handler(t, node.unique_name, node.full_name())
pre_node_name = current_graph().tensor_pre_node_dict[id(t)]
node.prev_nodes.append(current_graph().nodes_map[pre_node_name])
if id(t) in current_graph().tensor_pre_index_dict:
pre_node_index = current_graph().tensor_pre_index_dict[id(t)]
log.debug(f'propagate pre_index tensor {pre_node_name} {pre_node_index}')
node.prev_indices.append(pre_node_index)
else:
node.prev_indices.append(None)
if isinstance(t, torch.nn.Parameter):
if id(t) in current_graph().tensor_parameter_dict:
node.prev_tensors[i] = current_graph().tensor_parameter_dict[id(t)]()
else:
node.prev_tensors[i] = node.prev_tensors[i].data
current_graph().tensor_parameter_dict[id(t)] = weakref.ref(node.prev_tensors[i])
for i, t in enumerate(output_tensors):
if isinstance(t, (list, tuple)):
for j, rt in enumerate(t):
assert type(rt) in (torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter), (
f'Output [{i}][{j}] of {node.unique_name}({node.type()}) should be one of the following type '
f' [torch.dtype, torch.device, torch.Size, torch.Tensor], but got {type(rt)}'
)
current_graph().tensor_pre_node_dict[id(rt)] = node.unique_name
if need_idx:
log.debug(f'set pre_index tensor {i}, {j}')
current_graph().tensor_pre_index_dict[id(rt)] = [i, j]
else:
assert type(t) in (torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter), (
f'Output #{i} of {node.unique_name}({node.type()}) should be one of the following type '
f' [torch.dtype, torch.device, torch.Size, torch.Tensor], but got {type(t)}'
)
current_graph().tensor_pre_node_dict[id(t)] = node.unique_name
if need_idx:
log.debug(f'set pre_index tensor {i}')
current_graph().tensor_pre_index_dict[id(t)] = i
if isinstance(t, torch.nn.Parameter):
if id(t) in current_graph().tensor_parameter_dict:
node.next_tensors[i] = current_graph().tensor_parameter_dict[id(t)]()
else:
node.next_tensors[i] = node.next_tensors[i].data
current_graph().tensor_parameter_dict[id(t)] = weakref.ref(node.next_tensors[i])
current_graph().forward_nodes.append(node)
current_graph().nodes_map[node.unique_name] = node
@contextlib.contextmanager
def hook_modules(module):
"""Temporarily adds the hooks to a `nn.Module` for tracing"""
hooks = []
def register_submodule_tracer(module):
def _submodule_pre_tracer(module, input):
log.debug(f'pre tracer in _submodule_pre_tracer in {type(module).__name__}')
if lock():
skip_modules.add(weakref.ref(module))
lock(True)
def _submodule_tracer(module, inputs, outputs):
m_ref = weakref.ref(module)
if m_ref in skip_modules:
skip_modules.remove(m_ref)
return None
log.debug(f'tracer in _submodule_tracer in {type(module).__name__}')
node = TraceNode(module)
modified_outputs = noop_handler(node, inputs, outputs)
if modified_outputs is None:
add_forward_node(node, inputs, outputs)
else:
add_forward_node(node, inputs, modified_outputs)
lock(False)
return modified_outputs
module_unique_name = current_graph().module_unique_name_dict[id(module)]
if module_unique_name in current_graph().traced_modules:
log.debug(f"module {module_unique_name} is traced")
return None
related = False
if id(module) in module_constructor_traced:
if (
id(module) in module_constructor_lines
and module_constructor_weakrefs.get(id(module), type(None))() is not None
):
related = True
else:
if type(module) in overridable_modules:
related = True
else:
for m in overridable_modules:
if isinstance(module, m):
related = True
break
if related:
hooks.append(module.register_forward_pre_hook(_submodule_pre_tracer))
hooks.append(module.register_forward_hook(_submodule_tracer))
current_graph().related_modules.append(module_unique_name)
current_graph().traced_modules.append(module_unique_name)
return None
def _model_pre_tracer(module, inputs):
log.debug('pre tracer in _model_pre_tracer')
for i in inputs:
node = TraceNode(TraceFunction("input"))
add_input_node(node, i)
def _model_tracer(module, inputs, outputs):
log.debug('tracer in _model_tracer')
if type(outputs) is torch.Tensor:
node = TraceNode(TraceFunction("output"))
add_output_node(node, outputs)
elif isinstance(outputs, (list, tuple)):
for i in outputs:
if type(i) is torch.Tensor or (
isinstance(i, (list, tuple)) and all((type(x) is torch.Tensor for x in i))
):
node = TraceNode(TraceFunction("output"))
add_output_node(node, i)
else:
log.warning(
"Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or"
" tuple"
)
elif isinstance(outputs, dict):
for k, v in outputs.items():
if type(v) is torch.Tensor or (
isinstance(v, (list, tuple)) and all((type(x) is torch.Tensor for x in v))
):
node = TraceNode(TraceFunction("output"))
add_output_node(node, v)
else:
log.warning(
"Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or"
" tuple"
)
else:
log.warning(f'Output type is not supported: {type(outputs).__name__}, trying to extract tensors from it')
for k in outputs.__dir__():
v = getattr(outputs, k)
if type(v) is torch.Tensor or (type(v) in (list, tuple) and all((type(x) is torch.Tensor for x in v))):
node = TraceNode(TraceFunction("output"))
add_output_node(node, v)
log.debug('trace: apply register_submodule_tracer')
module.apply(register_submodule_tracer)
log.debug('trace: add hooks')
hooks.append(module.register_forward_pre_hook(_model_pre_tracer))
hooks.append(module.register_forward_hook(_model_tracer))
yield module
for hook in hooks:
hook.remove()
@contextlib.contextmanager
def tracer_context():
"""Basic context manager for tracing"""
yield True
lock(False)
module_constructor_traced.clear()
module_constructor_lines.clear()
module_constructor_weakrefs.clear()
original_values_for_tracked_objects.clear()
skip_modules.clear()
@contextlib.contextmanager
def model_constructor_tracer():
"""Basic context manager for capturing constructors for `nn.Module`"""
with patch_helper(wrap_funcs=False, wrap_creation_funcs=False, wrap_tracking_modules=False):
yield True
@contextlib.contextmanager
def ignore_mod_param_update_warning():
if not mod_param_update_warning_ignore():
mod_param_update_warning_ignore(True)
yield True
mod_param_update_warning_ignore(False)
else:
yield False
@contextlib.contextmanager
def model_tracer():
"""Simple context manager for tracing. Also captures module constructors"""
with tracer_context():
with model_constructor_tracer():
yield True
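# Typical top-level usage (a hedged sketch: `trace` and
# `TraceGraph.generate_code` live outside this excerpt, and the model class,
# input shape and output paths below are placeholders):
#
#   with model_tracer():
#       model = MyModel()                         # constructors are captured here
#       dummy_input = torch.randn(1, 3, 224, 224)
#       graph = trace(model, dummy_input)
#       graph.generate_code('out/my_model.py', 'out/my_model.pth', 'MyModel')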
@contextlib.contextmanager
def construct_trace_graph(
module, dummy_input: torch.Tensor, eliminate_dead_graph: bool, patch_torch_size: bool
) -> 'TraceGraph':
"""Simple context manager for creating a new TraceGraph"""
current_graph(TraceGraph(module, dummy_input, eliminate_dead_graph, patch_torch_size))
yield current_graph.get_value()
current_graph(None)
@contextlib.contextmanager
def override_current_trace_graph(new_graph: 'TraceGraph') -> 'TraceGraph':
"""Simple context manager for creating a new TraceGraph"""
old_graph = current_graph.get_value()
current_graph(new_graph)
yield current_graph.get_value()
current_graph(old_graph)
@contextlib.contextmanager
def import_patcher():
with tracer_context():
with patch_helper(wrap_creation_funcs=False, wrap_funcs=True, wrap_modules=False, wrap_tracking_modules=False):
with no_catch():
yield True
class TraceGraph(object):
"""A data structure for storing a computation graph"""
global_functions: typing.Dict[str, int]
global_nodes: typing.Dict[str, int]
module_unique_name_dict: typing.Dict[int, str]
module_original_name_dict: typing.Dict[int, str]
traced_modules: typing.List[str]
input_nodes: typing.List[TraceNode]
forward_nodes: typing.List[TraceNode]
output_nodes: typing.List[TraceNode]
constant_nodes: typing.List[TraceNode]
other_init_nodes: typing.List[TraceNode]
nodes_map: typing.Dict[str, TraceNode]
tensor_pre_node_dict: typing.Dict[int, str]
tensor_pre_index_dict: typing.Dict[int, int]
module: torch.nn.Module
dummy_input: torch.Tensor
eliminate_dead_graph: bool
inited: bool
quantized: bool
code: str
def __init__(
self,
module: torch.nn.Module,
dummy_input: torch.Tensor,
eliminate_dead_graph: bool = False,
patch_torch_size: bool = False,
):
# Used for function / node numbering
self.global_functions = {}
self.global_nodes = {}
# Unique name for modules and submodules
self.module_unique_name_dict = {}
self.module_original_name_dict = {}
# Unique name for buffers and parameters
self.parameter_original_name_dict = {}
self.parameter_unique_name_dict = {}
# Mapping between parameters and their parent modules
self.parameter_module_dict = {}
# Recording traced modules
self.traced_modules = []
self.related_modules = []
# Recording nodes
self.input_nodes = []
self.forward_nodes = []
self.output_nodes = []
self.constant_nodes = []
self.other_init_nodes = []
# Node <-> name mapping
self.nodes_map = {}
# Recording the previous node of the tensors
self.tensor_pre_node_dict = {}
# Recording the previous index of the tensors
self.tensor_pre_index_dict = {}
# Recording the tensor object of the parameters
self.tensor_parameter_dict = {}
# Input module
if isinstance(module, DataParallel) or isinstance(module, DistributedDataParallel):
log.error(
'You are tracing a parallel module, which is unsupported. Please pass in a raw model using `.module`.'
)
assert False
else:
self.module = module
# Input data
self.dummy_input = dummy_input
# Whether to keep inactive nodes after tracing
self.eliminate_dead_graph = eliminate_dead_graph
# Let's give the module and its children a name.
self.__tag_nodes()
# Whether the tracing is completed or not
self.inited = False
# Whether the graph is rewritten to be a quantized one
self.quantized = False
# Generated code
self.code = "None"
# Used namespaces
self.used_namespaces = set()
# Tracer options
self.patch_torch_size = patch_torch_size
def all_nodes(self) -> typing.List[TraceNode]:
"""Returns all the nodes in a computation graph during forward process"""
return self.input_nodes + self.forward_nodes + self.output_nodes + self.constant_nodes
def all_tensors(self) -> typing.List[torch.Tensor]:
tensors = dict()
for n in self.all_nodes():
for t in n.prev_tensors + n.next_tensors:
if isinstance(t, torch.Tensor):
tensors[id(t)] = t
return list(tensors.values())
def __tag_nodes(self) -> None:
"""Gives the modules and the submodules a unique name"""
# Tag submodules
for n, m in self.module.named_modules():
n = n.replace(".", "_")
n = n.replace("-", "_")
n = n.replace("/", "_")
n = n.replace(";", "_")
n = 'module_' + n if n.isnumeric() else n
self.module_unique_name_dict[id(m)] = n
# Tag the module itself
self.module_unique_name_dict[id(self.module)] = type(self.module).__name__
for n, p in self.module.named_parameters():
n = n.replace(".", "_")
n = n.replace("-", "_")
n = n.replace("/", "_")
n = n.replace(";", "_")
n = 'param_' + n if n.isnumeric() else n
self.parameter_unique_name_dict[id(p)] = n
for n, b in self.module.named_buffers():
n = n.replace(".", "_")
n = n.replace("-", "_")
n = n.replace("/", "_")
n = n.replace(";", "_")
n = 'buffer_' + n if n.isnumeric() else n
self.parameter_unique_name_dict[id(b)] = n
q = queue.Queue()
q.put((self.module, ''))
while not q.empty():
m, p = q.get()
if isinstance(m, nn.Module):
self.module_original_name_dict[id(m)] = p
for n, c in m.named_children():
if isinstance(m, (nn.Sequential, nn.ModuleList)) and n.isnumeric():
c_p = f'{p}[{n}]'
elif isinstance(m, nn.ModuleDict):
c_p = f'{p}["{n}"]'
else:
if len(p) > 0:
if '.' in n or '-' in n:
c_p = f'{p}.get_submodule("{n}")'
else:
c_p = f'{p}.{n}'
else:
c_p = n
q.put((c, c_p))
for n, c in m.named_parameters(recurse=False):
self.parameter_module_dict[id(c)] = id(m)
if len(p) > 0:
if '.' in n or '-' in n:
c_p = f'{p}.get_parameter("{n}")'
else:
c_p = f'{p}.{n}'
else:
c_p = n
q.put((c, c_p))
for n, c in m.named_buffers(recurse=False):
self.parameter_module_dict[id(c)] = id(m)
if len(p) > 0:
if '.' in n or '-' in n:
c_p = f'{p}.get_buffer("{n}")'
else:
c_p = f'{p}.{n}'
else:
c_p = n
q.put((c, c_p))
else:
self.parameter_original_name_dict[id(m)] = p
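# A sketch of the name sanitization performed by `__tag_nodes` (the module
# paths below are hypothetical):
#
#   "backbone.0.conv" -> unique name "backbone_0_conv"   (dots become underscores)
#   child "0" of an nn.Sequential/nn.ModuleList -> original name "backbone[0]"
#   child "stem" of an nn.ModuleDict            -> original name 'blocks["stem"]'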
def __active_detection(self, node: TraceNode):
"""Detects whether the node is active or not"""
q = queue.Queue()
q.put(node)
while not q.empty():
node = q.get()
if not node.active:
node.active = True
for i in node.prev_nodes:
q.put(i)
def init(self) -> None:
"""Builds a computation graph"""
if self.inited:
return
with no_catch():
device = get_module_device(self.module)
with self.__numbering_context():
if type(self.dummy_input) is torch.Tensor:
actual_input = [self.dummy_input]
elif isinstance(self.dummy_input, (tuple, list)):
actual_input = list(self.dummy_input)
else:
log.error(f'Unsupported type {type(self.dummy_input)} for dummy input')
assert False
for i in range(len(actual_input)):
dummy_input = actual_input[i]
if type(dummy_input) is torch.Tensor:
new_input = dummy_input.detach().clone()
if new_input.is_floating_point():
new_input.requires_grad = True
if new_input.device != device:
new_input = new_input.to(device=device)
actual_input[i] = new_input
original_state_dict = copy.deepcopy(self.module.state_dict())
with patch_helper(wrap_tracking_modules=self.patch_torch_size):
with hook_modules(self.module):
self.module(*actual_input)
self.module.load_state_dict(original_state_dict)
if self.eliminate_dead_graph:
self.eliminate_dead_graph_pass()
self.recompute_forward_order()
self.inited = True
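# A minimal construction sketch (illustrative; `MyModel` is hypothetical).
# `init()` runs the dummy input through the patched model to build the graph
# and then restores the original state dict:
#
#   model = MyModel().eval()
#   graph = TraceGraph(model, torch.ones(1, 3, 224, 224), eliminate_dead_graph=True)
#   graph.init()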
def eliminate_dead_graph_pass(self):
for n in self.input_nodes + self.forward_nodes + self.output_nodes:
n.active = False
for i in self.output_nodes:
self.__active_detection(i)
active_input_nodes = [i for i in self.input_nodes if i.active]
active_forward_nodes = [i for i in self.forward_nodes if i.active]
active_constant_nodes = dict()
for n in self.forward_nodes + self.output_nodes:
if n.active:
for pn in n.prev_nodes:
if isinstance(pn.module, ConstantNode):
active_constant_nodes[pn.unique_name] = pn
self.input_nodes = active_input_nodes
self.forward_nodes = active_forward_nodes
self.constant_nodes = list(active_constant_nodes.values())
def add_inputs_for_tensors(self, input_tensors):
"""Add input nodes for specific tensors"""
with self.__numbering_context():
self.global_functions['input'] = len(self.input_nodes) - 1
for i, t in enumerate(input_tensors):
node = self.nodes_map[self.tensor_pre_node_dict[id(t)]]
with override_current_trace_graph(self):
new_node = TraceNode(TraceFunction("input"))
add_input_node(new_node, t)
node.prev_nodes.append(new_node)
self.recompute_forward_order()
def add_inputs_for_tensors_in_node(self, input_node, input_tensors):
"""Add input nodes for specific tensors with a node"""
with self.__numbering_context():
self.global_functions['input'] = len(self.input_nodes) - 1
for i, t in enumerate(input_tensors):
input_node.prev_tensors.append(t)
input_node.prev_indices.append(None)
with override_current_trace_graph(self):
new_node = TraceNode(TraceFunction("input"))
add_input_node(new_node, t)
input_node.prev_nodes.append(new_node)
self.recompute_forward_order()
def add_outputs_for_tensors(self, output_tensors):
"""Add output nodes for specific tensors"""
with self.__numbering_context():
self.global_functions['output'] = len(self.output_nodes) - 1
for i, t in enumerate(output_tensors):
node = self.nodes_map[self.tensor_pre_node_dict[id(t)]]
with override_current_trace_graph(self):
new_node = TraceNode(TraceFunction("output"))
add_output_node(new_node, t)
node.next_nodes.append(new_node)
self.recompute_forward_order()
def add_state_input_outputs(self):
for node in self.forward_nodes:
if not isinstance(node.module, (nn.LSTM, nn.GRU, nn.RNN)):
continue
if len(node.prev_tensors) == 0 and len(node.next_tensors) == 0:
continue
input_tensor = node.prev_tensors[0]
max_batch_size = input_tensor.size(0) if node.module.batch_first else input_tensor.size(1)
num_directions = 2 if node.module.bidirectional else 1
real_hidden_size = (
node.module.proj_size if getattr(node.module, 'proj_size', 0) > 0 else node.module.hidden_size
)
hidden_shape = (node.module.num_layers * num_directions, max_batch_size, real_hidden_size)
if len(node.prev_nodes) == 1:
hx = torch.zeros(
hidden_shape,
dtype=input_tensor.dtype,
device=input_tensor.device,
)
if isinstance(node.module, nn.LSTM):
cx = torch.zeros(
hidden_shape,
dtype=input_tensor.dtype,
device=input_tensor.device,
)
self.add_inputs_for_tensors_in_node(node, [hx, cx])
indices = (-2, -1)
else:
self.add_inputs_for_tensors_in_node(node, [hx])
indices = (-1,)
for i in indices:
dtype_str = str(self.input_nodes[i].next_tensors[0].dtype).replace('torch.', '')
shape_str = str([x for x in self.input_nodes[i].next_tensors[0].shape])
log.warning(f'Input added: {self.input_nodes[i].unique_name}: {dtype_str}{shape_str}')
if len(node.next_nodes) == 1:
if isinstance(node.module, nn.LSTM):
self.add_outputs_for_tensors(node.next_tensors[1])
indices = (-2, -1)
else:
self.add_outputs_for_tensors([node.next_tensors[1]])
indices = (-1,)
for i in indices:
dtype_str = str(self.output_nodes[i].prev_tensors[0].dtype).replace('torch.', '')
shape_str = str([x for x in self.output_nodes[i].prev_tensors[0].shape])
log.warning(f'Output added: {self.output_nodes[i].unique_name}: {dtype_str}{shape_str}')
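# A worked example of the hidden-state shape computed above (assuming a
# bidirectional nn.LSTM with num_layers=2, hidden_size=64, proj_size=0 and a
# batch-first input of shape (8, 10, 32)):
#   num_directions = 2, max_batch_size = 8, real_hidden_size = 64
#   hidden_shape = (2 * 2, 8, 64) = (4, 8, 64) for both hx and cx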
def reset_input_output_for_graph(self, input_names: typing.List[str], output_names: typing.List[str]):
"""Extract a subgraph from the original computation graph"""
assert len(set(input_names) & set(output_names)) == 0, "A node cannot be both input and output"
input_nodes = []
output_nodes = []
for name in input_names + output_names:
node = self.nodes_map.get(name, None)
assert node is not None, f"{name} is not a node in TraceGraph"
assert node in self.forward_nodes, f"{name} is not a forward node in TraceGraph"
if name in input_names:
input_nodes.append(node)
else:
output_nodes.append(node)
self.assert_is_subgraph(input_nodes, output_nodes)
for node in self.input_nodes + self.output_nodes:
del self.nodes_map[node.unique_name]
self.input_nodes.clear()
self.output_nodes.clear()
with self.__numbering_context():
for node in input_nodes:
node.prev_nodes.clear()
for i, t in enumerate(node.prev_tensors):
with override_current_trace_graph(self):
new_node = TraceNode(TraceFunction("input"))
add_input_node(new_node, t)
node.prev_nodes.append(new_node)
for node in output_nodes:
node.next_nodes.clear()
for i, t in enumerate(node.next_tensors):
if type(t) is torch.Tensor or (
isinstance(t, (list, tuple)) and all((type(x) is torch.Tensor for x in t))
):
with override_current_trace_graph(self):
new_node = TraceNode(TraceFunction("output"))
add_output_node(new_node, t)
node.next_nodes.append(new_node)
else:
log.warning(
"Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or"
" tuple"
)
self.eliminate_dead_graph_pass()
self.recompute_forward_order()
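# A usage sketch (the node names are hypothetical): cut the traced graph so
# that it starts at "conv_1" and ends at "relu_5"; nodes outside that range
# are removed by the dead-graph pass.
#
#   graph.reset_input_output_for_graph(['conv_1'], ['relu_5'])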
def assert_is_subgraph(self, input_nodes, output_nodes):
q = queue.Queue()
for node in output_nodes:
q.put(node)
start_nodes = {node.unique_name: node for node in input_nodes}
visited = set()
actual_starts = dict()
while not q.empty():
node = q.get()
if node.unique_name in visited:
continue
visited.add(node.unique_name)
if node.unique_name in start_nodes:
continue
if len(node.prev_nodes) == 0:
if len(node.next_nodes) != 0:
actual_starts[node.unique_name] = node
continue
for prev_node in node.prev_nodes:
q.put(prev_node)
missing_inputs = set(actual_starts) - set(start_nodes)
assert len(missing_inputs) == 0, f"Not a subgraph, missing inputs: {missing_inputs}"
@contextlib.contextmanager
def __numbering_context(self):
"""A simple context manager for numbering nodes"""
yield True
self.global_functions.clear()
self.global_nodes.clear()
def __gen_init_code(self) -> str:
"""Generates the code for the init function for a `nn.Module`"""
generated_node = []
lines = []
mod_ids = []
mod_cache_dict = dict()
for node in self.constant_nodes + self.forward_nodes + self.other_init_nodes:
if node.unique_name in generated_node or id(node.module) in mod_ids:
log.info(f"skip dumplicate node code gen {node.unique_name}")
continue
generated_node.append(node.unique_name)
mod_ids.append(id(node.module))
if id(node.module) in module_constructor_lines:
root_ns = qualified_name(node.type()).split('.')[0]
if isinstance(node.module, TraceFunction) and (
node.module.full_name.startswith('torch.')
or node.module.full_name.startswith('self.')
or '.' not in node.module.full_name
):
continue
self.used_namespaces.add(root_ns)
orig_constructor_line = module_constructor_lines[id(node.module)]
line = f' self.{node.unique_name} = {orig_constructor_line}'
lines.append(line)
elif type(node.module) is ConstantNode:
# Parameter generation
self.used_namespaces.add('torch')
requires_grad_prop = ''
if node.module.is_parameter != node.module.requires_grad:
requires_grad_prop = f', requires_grad={node.module.requires_grad}'
if node.module.is_parameter:
line = (
f' self.register_parameter("{node.unique_name}",'
f' torch.nn.Parameter(torch.empty({node.module.shape}, dtype={node.module.dtype})'
f'{requires_grad_prop}))'
)
elif node.module.is_persistent:
line = (
f' self.register_buffer("{node.unique_name}", torch.empty({node.module.shape},'
f' dtype={node.module.dtype}{requires_grad_prop}))'
)
else:
line = (
f' self.register_buffer("{node.unique_name}", torch.tensor({node.module.data_str},'
f' dtype={node.module.dtype}{requires_grad_prop}), persistent=False)'
)
lines.append(line)
elif type(node.module) is not TraceFunction:
# Generate the module even if the constructor is not caught
log.info(
f'the constructor of the module {node.unique_name} of type {type(node.module).__name__} is not'
' traced, trying the experimental way'
)
root_ns = qualified_name(node.type()).split('.')[0]
self.used_namespaces.add(root_ns)
orig_constructor_line, mod_cache = gen_module_constrctor_line(node.module, mod_cache_dict)
line = f' self.{node.unique_name} = {orig_constructor_line}'
lines.append(line)
mod_cache_dict.update(mod_cache)
block = "\n".join(lines)
return block
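# Illustrative lines produced by `__gen_init_code` (the module names, arguments
# and shapes below are hypothetical):
#
#   self.conv_0 = torch.nn.Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
#   self.register_parameter("weight_1", torch.nn.Parameter(torch.empty([16], dtype=torch.float32)))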
def __gen_forward_code(self, inplace=False) -> str:
"""Generates the code for the forward function for a `nn.Module`"""
lines = [f" def forward(self, {','.join([i.unique_name for i in self.input_nodes])}):"]
mod_name_dict = {}
for node in self.forward_nodes:
output = ", ".join([node.unique_name])
param = ", ".join([node.prev_node_unique_name(i, inplace) for i in range(len(node.prev_nodes))])
if type(node.module) is TraceFunction:
full_name = node.full_name()
if not full_name.startswith('torch.') and not full_name.startswith('self.') and '.' in full_name:
ns = '.'.join(full_name.split('.')[:-1])
self.used_namespaces.add(ns)
first_arg = None
if node.is_class():
first_arg = node.prev_node_unique_name(0, inplace)
if node.type().startswith('__i') and node.type().endswith('__'):
inner_op = node.module.func_type[3:-2]
if inner_op in SPECIAL_OPERATORS:
node.module.func_type = f'__{inner_op}__'
parts = node.module.full_name.split('.')[:-1] + [node.module.func_type]
node.module.full_name = '.'.join(parts)
if first_arg is not None:
alias = first_arg
else:
alias = node.module.get_tensor_name(0, inplace)
node.module.add_alias(alias)
aliases = node.module.get_aliases()
prefix = ''
if aliases is not None:
prefix = ''.join([f'{x} = ' for x in aliases])
line = f" {prefix}{output} = {node.module.extra_expr(first=first_arg, original=inplace)}"
else:
if inplace:
mod_name = node.original_name
else:
mod_name_dict.setdefault(node.module, node.unique_name)
mod_name = mod_name_dict[node.module]
if len(node.prev_tensors) == 0 and len(node.next_tensors) == 0:
continue
if node.type() is nn.LSTM and len(node.prev_nodes) == 3 and len(node.prev_tensors) == 3:
first_arg = node.prev_node_unique_name(0)
param = ", ".join([node.prev_node_unique_name(i) for i in range(1, len(node.prev_nodes))])
line = f" {output} = self.{mod_name}({first_arg}, ({param}))"
else:
line = f" {output} = self.{mod_name}({param})"
lines.append(line)
for pn in {pn.unique_name: pn for pn in node.prev_nodes}.values():
if node.forward_order == max([n.forward_order for n in pn.next_nodes]):
if pn.type() not in (ConstantNode, torch.nn.quantized.FloatFunctional):
lines.append(f" {pn.unique_name} = None")
def _gen_output_node(node):
if node.rev_index:
return f'[{", ".join([node.prev_node_unique_name(i) for i in range(len(node.prev_nodes))])}]'
else:
return node.prev_node_unique_name(0)
lines.append(f" return {', '.join([_gen_output_node(i) for i in self.output_nodes])}")
block = "\n".join(lines)
return block
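# Illustrative output of `__gen_forward_code` (the node names are
# hypothetical); intermediate variables are set to None once no later node
# uses them:
#
#   def forward(self, input_0_f):
#       conv_0 = self.conv_0(input_0_f)
#       relu_0 = torch.nn.functional.relu(conv_0)
#       input_0_f = None
#       return relu_0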
def __gen_import_code(self) -> str:
"""Generates the code for the import section for a `nn.Module`"""
# TODO: Selective module importing
import_block = """import torch\nimport torch.nn\nimport torch.functional\nimport torch.nn.functional"""
if self.quantized is True:
import_block += '\nimport torch.quantization\nimport torch.nn.quantized'
additional_ns = sorted(self.used_namespaces - set(['torch']))
import_block += ''.join([f'\nimport {ns}' for ns in additional_ns])
return import_block
def __gen_input_code(self) -> str:
"""Generates the code for the input section for the code to invoke `nn.Module`"""
input_block = ""
for i, node in enumerate(self.input_nodes):
shape = ", ".join((str(i) for i in node.next_tensors[0].shape))
dtype = node.next_tensors[0].dtype
if i != 0:
input_block += "\n"
input_block += f" dummy_input_{i} = torch.ones(({shape}), dtype={dtype})"
return input_block
def recompute_forward_order(self):
forward_order = 0
for n in self.input_nodes + self.forward_nodes + self.output_nodes:
n.forward_order = forward_order
forward_order += 1
for prev in n.prev_nodes:
if n not in prev.next_nodes:
prev.next_nodes.append(n)
def generate_code(
self,
output_script_path: typing.Optional[str],
output_weight_path: typing.Optional[str],
model_name: str = 'DefaultModel',
check: bool = False,
) -> bool:
"""The main function for code generation"""
output_paths = (output_script_path, output_weight_path)
for output_path in output_paths:
if not output_path:
continue
output_dir = os.path.dirname(output_path)
if output_dir != '' and not os.path.exists(output_dir):
os.makedirs(output_dir)
DummyModel = type('DummyModel', (torch.nn.Module,), {})
dummy_model = DummyModel()
mod_ids = set()
for node in self.forward_nodes:
if id(node.module) not in mod_ids:
setattr(dummy_model, node.unique_name, node.module)
mod_ids.add(id(node.module))
for node in self.constant_nodes:
if node.module.is_parameter or node.module.is_persistent:
dtype = getattr(torch, node.module.dtype.split('.')[-1])
new_tensor = torch.tensor(node.module.data, dtype=dtype)
if node.module.is_parameter:
weight = torch.nn.Parameter(new_tensor)
dummy_model.register_parameter(node.unique_name, weight)
else:
dummy_model.register_buffer(node.unique_name, new_tensor)
if output_weight_path:
torch.save(dummy_model.state_dict(), output_weight_path)
output_weight_path_str = output_weight_path.replace('\\', '\\\\')
init_block = self.__gen_init_code()
forward_block = self.__gen_forward_code()
import_block = self.__gen_import_code()
input_block = self.__gen_input_code()
context = {
"import_block": import_block,
"init_block": init_block,
"forward_block": forward_block,
"name_block": model_name,
"load_weight_block": ""
if output_weight_path is None
else f" model.load_state_dict(torch.load('{output_weight_path_str}'))",
"input_block": input_block,
"input_names": ", ".join([f"dummy_input_{i}" for i in range(len(self.input_nodes))]),
}
code = MODULE_TEMPLATE % context
if output_script_path:
if os.path.exists(output_script_path):
os.remove(output_script_path)
with io.open(output_script_path, 'w') as f:
f.write(code)
if check:
training_state = self.module.training
valid = True
with torch.no_grad():
self.module.eval()
original_outputs = tensors2ndarray(self.module(*self.dummy_input))
new_module = import_from_path(f'tracer_check.{model_name}', output_script_path, model_name)()
new_module.eval()
if output_weight_path:
new_module.load_state_dict(torch.load(output_weight_path))
new_outputs = tensors2ndarray(new_module(*self.dummy_input))
for i in range(len(original_outputs)):
output = original_outputs[i]
new_output = new_outputs[i]
if not np.allclose(output, new_output):
log.warning(f"[WARNING] Output {i} is not equal.")
valid = False
if training_state:
self.module.train()
return valid
else:
return True
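# A usage sketch for code generation (the paths and the model name are
# hypothetical); with `check=True`, the regenerated model is re-imported and
# its outputs are compared against the original ones.
#
#   graph.generate_code('out/my_model.py', 'out/my_model.pth', 'MyModel', check=True)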
def inplace_commit(self, show_code: bool = False):
"""Commit the changes to the TraceGraph and applies it to the model"""
forward_block = self.__gen_forward_code(True)
import_block = self.__gen_import_code()
code = re.sub(r'^ ', '', forward_block, flags=re.MULTILINE)
if show_code:
log.warning(f'The new forward function for the model:\n{code}')
tmp_ns = {}
exec(import_block, tmp_ns)
exec(code, tmp_ns)
new_func = tmp_ns['forward']
setattr(self.module, 'forward', types.MethodType(new_func, self.module))
for node in self.forward_nodes + self.other_init_nodes:
if isinstance(node.module, nn.Module) and node.unique_name not in self.related_modules:
setattr(self.module, node.original_name, node.module)
for node in self.constant_nodes:
if not node.module.inplace:
setattr(self.module, node.original_name, node.next_tensors[0])
return self.module
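# A usage sketch (illustrative): after editing the graph, rebind the rewritten
# forward function and submodules onto the original model object.
#
#   patched_model = graph.inplace_commit(show_code=True)
#   patched_model(torch.ones(1, 3, 224, 224))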
def update_submodule_in_nodes_from_predicate(
self,
nodes: typing.List[TraceNode],
module_gen_predicate: typing.Callable[[nn.Module], nn.Module],
inplace: bool = False,
):
"""update a submodule from the nodes using the predicate given"""
for node in nodes:
module = node.module
new_module = module_gen_predicate(module)
self.update_submodule_in_node(node, new_module, inplace)
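# A usage sketch (illustrative): swap the module of every matched node, e.g.
# replace all Dropout layers with nn.Identity.
#
#   nodes = graph.filter_forward_nodes(lambda n, _: isinstance(n.module, nn.Dropout))
#   graph.update_submodule_in_nodes_from_predicate(nodes, lambda m: nn.Identity())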
def get_submodule_with_parent_from_name(self, module_name: str, inplace: bool = False):
"""Gets the submodule with its parent using the name given"""
if inplace:
module_name = re.sub('get_submodule\\("(.*?)"\\)', '\\1', module_name)
module_name = re.sub('\\[("|)(.*?)("|)\\]', '.\\2', module_name)
module_name_parts = module_name.split('.')
cur_obj = self.module
last_obj = None
for ns in module_name_parts:
last_obj = cur_obj
if type(cur_obj) is nn.ModuleList:
cur_obj = cur_obj[int(ns)]
elif type(cur_obj) is nn.ModuleDict:
cur_obj = cur_obj[ns]
else:
cur_obj = getattr(cur_obj, ns)
return cur_obj, last_obj
def update_submodule_in_node(self, node: TraceNode, module: nn.Module, inplace: bool = False):
"""update a submodule from the nodes using the module given"""
module_name = self.module_original_name_dict[id(node.module)]
if inplace:
module_name = re.sub('get_submodule\\("(.*?)"\\)', '\\1', module_name)
module_name = re.sub('\\[("|)(.*?)("|)\\]', '.\\2', module_name)
module_name_parts = module_name.split('.')
cur_obj = self.module
for ns in module_name_parts[:-1]:
if type(cur_obj) is nn.ModuleList:
cur_obj = cur_obj[int(ns)]
elif type(cur_obj) is nn.ModuleDict:
cur_obj = cur_obj[ns]
else:
cur_obj = getattr(cur_obj, ns)
ns = module_name_parts[-1]
new_obj = module
if type(cur_obj) is nn.ModuleList:
cur_obj[int(ns)] = new_obj
elif type(cur_obj) is nn.ModuleDict:
cur_obj[ns] = new_obj
else:
setattr(cur_obj, ns, new_obj)
def filter_forward_nodes(self, predicate, custom_data=None, reverse=False) -> typing.List[TraceNode]:
"""A utility function for filtering forward nodes"""
nodes = []
iter_nodes = self.forward_nodes
if reverse:
iter_nodes = reversed(iter_nodes)
for node in iter_nodes:
if predicate(node, custom_data):
nodes.append(node)
return nodes
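# A usage sketch (illustrative): the predicate receives each forward node and
# the optional `custom_data`; `reverse=True` walks the forward nodes from the
# end of the graph.
#
#   def is_kind(node, kinds):
#       return node.kind() in kinds
#
#   matches = graph.filter_forward_nodes(is_kind, custom_data=('relu', 'relu6'), reverse=True)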
def insert_after(self, node: TraceNode, module, next_tensors: typing.Optional[typing.List[torch.Tensor]] = None):
"""Insert a module or an existing node after a node in the computation graph"""
# Create a new node and connect it to the next node/tensors
if type(module) is not TraceNode:
new_node = TraceNode(module, cur_graph=self)
if node in self.input_nodes or node in self.constant_nodes:
self.forward_nodes.insert(0, new_node)
elif node in self.output_nodes:
log.error('You cannot insert a node after output nodes')
assert False
else:
idx = self.forward_nodes.index(node)
self.forward_nodes.insert(idx + 1, new_node)
self.nodes_map[new_node.unique_name] = new_node
else:
new_node = module
is_constant_node = type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
new_node.prev_nodes.append(node)
new_node.next_nodes.extend(node.next_nodes)
if next_tensors is None:
next_tensors = [None] * len(node.next_tensors)
for t, new_t in zip(node.next_tensors, next_tensors):
if new_t is None:
new_t = t.clone()
self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name
new_node.prev_tensors.append(t)
new_node.next_tensors.append(new_t)
new_node.prev_indices.append(None)
# Make input/constant nodes connect to the new node
node.next_nodes.clear()
node.next_nodes.append(new_node)
# Connect the next nodes to the new node
tensor_replace_dict = dict(zip(new_node.prev_tensors, new_node.next_tensors))
for next_node in new_node.next_nodes:
is_next_constant_node = type(next_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
for i, n in enumerate(next_node.prev_nodes):
if n == node:
next_node.prev_nodes[i] = new_node
# Make sure the data is writable
if isinstance(next_node.prev_tensors, tuple):
next_node.prev_tensors = list(next_node.prev_tensors)
updated_indices = []
for i, t in enumerate(next_node.prev_tensors):
if t in tensor_replace_dict:
next_node.prev_tensors[i] = tensor_replace_dict[t]
updated_indices.append(next_node.prev_indices[i])
# Since the function calls are rendered beforehand,
# we need to change them as well.
if type(next_node.module) is TraceFunction:
if next_node.module.args_string is not None:
for idx in updated_indices:
old_unique_name = tensor_name_from_parts(node.unique_name, idx, is_constant_node)
new_unique_name = tensor_name_from_parts(new_node.unique_name, idx, is_next_constant_node)
next_node.module.replace_tensor_name(old_unique_name, new_unique_name)
next_node.module.update_args_string()
def insert_new_after(
self,
node,
module_or_func,
prev_tensors: typing.List[torch.Tensor],
prev_indices: typing.List[torch.Tensor],
next_tensors: typing.Optional[typing.List[torch.Tensor]] = None,
before_node: typing.Optional[TraceNode] = None,
):
assert type(module_or_func) is not TraceNode
new_node = TraceNode(module_or_func, cur_graph=self)
if next_tensors is None:
next_tensors = [t.clone() for t in prev_tensors]
for new_t, new_i in zip(next_tensors, prev_indices):
self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name
if new_i is not None:
self.tensor_pre_index_dict[id(new_t)] = new_i
new_node.prev_tensors.extend(prev_tensors)
new_node.next_tensors.extend(next_tensors)
new_node.prev_indices.extend(prev_indices)
new_node.prev_nodes.append(node)
node.next_nodes.append(new_node)
if before_node is not None:
idx = self.forward_nodes.index(before_node)
self.forward_nodes.insert(idx, new_node)
else:
self.forward_nodes.append(new_node)
self.nodes_map[new_node.unique_name] = new_node
return new_node
def insert_between(
self,
prev_node: TraceNode,
next_node: TraceNode,
module,
next_tensors: typing.Optional[typing.List[torch.Tensor]] = None,
move_idx: bool = False,
tensor_ptrs: typing.Optional[typing.Set[int]] = None,
):
"""Insert a module or an existing node between two nodes in the computation graph"""
# Create a new node and connect it to the previous node/tensors
old_unique_name = prev_node.unique_name
is_constant_node = type(prev_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
if type(module) is not TraceNode:
new_node = TraceNode(module, cur_graph=self)
if prev_node not in next_node.prev_nodes or next_node not in prev_node.next_nodes:
log.error('You cannot insert a node between two nodes that are not connected')
assert False
idx = self.forward_nodes.index(next_node)
self.forward_nodes.insert(idx, new_node)
self.nodes_map[new_node.unique_name] = new_node
new_node.prev_nodes.append(prev_node)
new_node.next_nodes.append(next_node)
else:
new_node = module
if prev_node not in new_node.prev_nodes:
new_node.prev_nodes.append(prev_node)
if next_node not in new_node.next_nodes:
new_node.next_nodes.append(next_node)
if next_tensors is None:
next_tensors = list(new_node.next_tensors)
new_node.prev_tensors.clear()
new_node.next_tensors.clear()
new_node.prev_indices.clear()
is_new_constant_node = type(new_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
# Gather tensors from previous nodes
prev_tensors = []
prev_indices = []
remaining_count = 0
for pt, pidx in zip(next_node.prev_tensors, next_node.prev_indices):
skip_node = tensor_ptrs is not None and id(pt) not in tensor_ptrs
for nt in prev_node.next_tensors:
if id(pt) == id(nt):
if not skip_node:
prev_tensors.append(pt)
prev_indices.append(pidx)
else:
remaining_count += 1
break
elif isinstance(nt, (list, tuple)):
for i, ntt in enumerate(nt):
if id(ntt) == id(pt):
if not skip_node:
prev_tensors.append(pt)
prev_indices.append(pidx)
else:
remaining_count += 1
break
if next_tensors is None:
next_tensors = [None] * len(prev_tensors)
assert len(next_tensors) > 0, f'{prev_node.unique_name} and {next_node.unique_name} have no common tensors'
for idx, (t, new_t, pidx) in enumerate(zip(prev_tensors, next_tensors, prev_indices)):
if new_t is None:
new_t = t.clone()
next_tensors[idx] = new_t
self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name
new_node.prev_tensors.append(t)
new_node.next_tensors.append(new_t)
new_node.prev_indices.append(pidx if move_idx else None)
# Make output nodes connect to the new node
for idx in range(len(next_node.prev_nodes) - remaining_count):
if next_node.prev_nodes[idx] == prev_node:
next_node.prev_nodes[idx] = new_node
# Update tensors in output nodes
index_mapping = []
for idx, t in enumerate(next_node.prev_tensors):
for pt, nt in zip(prev_tensors, next_tensors):
if id(t) == id(pt):
next_node.prev_tensors[idx] = nt
old_index = next_node.prev_indices[idx]
if move_idx:
next_node.prev_indices[idx] = None
new_index = next_node.prev_indices[idx]
index_mapping.append((old_index, new_index))
break
# Connect the previous nodes to the new node
for prev_node in new_node.prev_nodes:
if remaining_count > 0:
if new_node not in prev_node.next_nodes:
prev_node.next_nodes.append(new_node)
else:
for i, n in enumerate(prev_node.next_nodes):
if n == next_node:
prev_node.next_nodes[i] = new_node
break
# Update previous node name for next nodes (TraceFunction)
if type(next_node.module) is TraceFunction:
for old_idx, new_idx in index_mapping:
prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node)
next_unique_name = tensor_name_from_parts(new_node.unique_name, new_idx, is_new_constant_node)
log.debug(f'tensor rename: {prev_unique_name} -> {next_unique_name}')
next_node.module.replace_tensor_name(prev_unique_name, next_unique_name)
next_node.module.update_args_string()
def insert_before(
self,
node: TraceNode,
module,
next_tensors: typing.Optional[typing.List[torch.Tensor]] = None,
move_idx: bool = False,
next_indices: typing.Optional[typing.List[int]] = None,
):
"""Insert a module or an existing node before a node in the computation graph"""
# Create a new node and connect it to the previous node/tensors
if type(module) is not TraceNode:
if not isinstance(module, (tuple, list)):
modules = [module]
rev_mode = False
else:
if not node.rev_index:
log.error('You can only insert nodes with a list of modules when node.rev_index=True')
assert False
if len(module) != len(node.prev_nodes):
log.error(f'The number of the modules provided is wrong, expected: {len(node.prev_nodes)}')
assert False
modules = module
rev_mode = True
new_nodes: typing.List[TraceNode] = []
for module in modules:
new_node = TraceNode(module, cur_graph=self)
new_nodes.append(new_node)
else:
new_nodes = [module]
if node in self.input_nodes or node in self.constant_nodes:
log.error('You cannot insert a node before input/constant nodes')
assert False
elif node in self.output_nodes:
for new_node in new_nodes:
self.forward_nodes.append(new_node)
else:
for new_node in new_nodes:
idx = self.forward_nodes.index(node)
self.forward_nodes.insert(idx, new_node)
for idx, new_node in enumerate(new_nodes):
self.nodes_map[new_node.unique_name] = new_node
if not rev_mode:
new_node.prev_nodes.extend(node.prev_nodes)
else:
new_node.prev_nodes.append(node.prev_nodes[idx])
new_node.next_nodes.append(node)
index_mapping = []
for idx, new_node in enumerate(new_nodes):
if not rev_mode:
prev_tensors = node.prev_tensors
prev_indices = node.prev_indices
if move_idx:
node.prev_indices = [None] * len(node.prev_indices)
else:
prev_tensors = node.prev_tensors[idx : idx + 1]
prev_indices = node.prev_indices[idx : idx + 1]
if move_idx:
node.prev_indices[idx] = None
if next_tensors is None:
next_tensors = [None] * len(prev_tensors)
if next_indices is None:
next_indices = [None] * len(prev_tensors)
for t, new_t, ind, n_ind in zip(prev_tensors, next_tensors, prev_indices, next_indices):
if new_t is None:
new_t = t.clone()
self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name
new_node.prev_tensors.append(t)
new_node.next_tensors.append(new_t)
new_node.prev_indices.append(ind if move_idx else n_ind)
index_mapping.append((ind, new_node.prev_indices[-1]))
# Make output nodes connect to the new nodes
node.prev_nodes.clear()
node.prev_nodes.extend(new_nodes)
node.prev_tensors.clear()
for new_node in new_nodes:
node.prev_tensors.extend(new_node.next_tensors)
# Connect the previous nodes to the new node
if rev_mode:
for prev_node in new_node.prev_nodes:
idx = None
for i, n in enumerate(prev_node.next_nodes):
if n == node:
idx = i
break
if idx is not None:
for i, new_node in enumerate(new_nodes):
prev_node.next_nodes.insert(idx + 1 + i, new_node)
prev_node.next_nodes.pop(idx)
else:
for new_node in new_nodes:
for prev_node in new_node.prev_nodes:
for i, n in enumerate(prev_node.next_nodes):
if n == node:
prev_node.next_nodes[i] = new_node
break
# Update previous node name for next nodes (TraceFunction)
if type(node.module) is TraceFunction and node not in self.output_nodes:
new_node = new_nodes[0]
for i in range(len(new_node.prev_nodes)):
old_unique_name = new_node.prev_nodes[i].unique_name
is_constant_node = type(new_node.prev_nodes[i].module) in (
ConstantNode,
torch.nn.quantized.FloatFunctional,
)
is_next_constant_node = type(new_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
old_idx, new_idx = index_mapping[i]
prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node=is_constant_node)
next_unique_name = tensor_name_from_parts(
new_node.unique_name, new_idx, is_constant_node=is_next_constant_node
)
log.debug(f'node rename: {prev_unique_name} -> {next_unique_name}')
node.module.replace_tensor_name(prev_unique_name, next_unique_name)
node.module.update_args_string()
def __call__(self, *args, **kwargs):
"""Calls the function with a list of tensor inputs"""
# Prepare inputs
tensors_map = {}
i = 0
for t in args:
if isinstance(t, (tuple, list)):
for rt in t:
tensors_map[(f'input_{i}_f', None)] = rt
i += 1
else:
tensors_map[(f'input_{i}_f', None)] = t
i += 1
# Sort nodes
sorted_nodes = sorted(self.forward_nodes, key=lambda x: x.forward_order)
for node in sorted_nodes:
# Basic info
op_kind = node.kind()
op_kind = op_kind.__name__ if isinstance(op_kind, type) else op_kind
log.debug(f'{node.unique_name}({op_kind}):')
log.debug(' Inputs:')
# TODO: Support the following case
assert len(node.prev_nodes) == len(node.prev_tensors), "not supported"
# Collect input tensors
prev_tensors = []
for i, pn in enumerate(node.prev_nodes):
pn_name = pn.unique_name
pn_idx = node.prev_indices[i]
log.debug(f' {pn_name} {pn_idx}')
if isinstance(pn.module, ConstantNode):
prev_tensors.append(node.prev_tensors[i])
elif (pn_name, pn_idx) in tensors_map:
prev_tensors.append(tensors_map[(pn_name, pn_idx)])
else:
assert False, f"({pn_name}, {pn_idx}) is not found in tensor dict"
node.prev_tensors = prev_tensors
# Calculate output
if isinstance(node.module, nn.Module):
output = node.module(*prev_tensors)
else:
output = node.module(prev_tensors)
# Shape handling
if node.type() in ('shape', 'size'):
if len(prev_tensors) == 1:
output = torch.tensor(output).unbind(0)
else:
output = torch.tensor(output)
log.debug('')
log.debug(' Outputs:')
# Updating output tensors
if isinstance(output, (list, tuple)):
for j, t in enumerate(output):
tensors_map[(node.unique_name, j)] = t
node.next_tensors[j] = t
else:
tensors_map[(node.unique_name, None)] = output
node.next_tensors[0] = output
for i in range(len(node.next_tensors)):
log.debug(f' {i}')
log.debug('')
def replace_node_module(self, node: TraceNode, module: torch.nn.Module) -> None:
"""Replaces a module in a node with another"""
# Update unique name for node
old_unique_name = node.unique_name
is_constant_node = type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
is_new_constant_node = type(module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
node.unique_name = self.module_unique_name_dict[id(module)]
# Drop original module constructor line
if id(node.module) in module_constructor_lines:
if id(node.module) in module_constructor_traced:
module_constructor_traced.remove(id(node.module))
del module_constructor_lines[id(node.module)]
# Update module for node
node.module = module
# Update node map
del self.nodes_map[old_unique_name]
self.nodes_map[node.unique_name] = node
# Update previous node name for next tensors
for t in node.next_tensors:
self.tensor_pre_node_dict[id(t)] = self.tensor_pre_node_dict[id(t)].replace(
old_unique_name, node.unique_name
)
# Update previous node name for next nodes (TraceFunction)
for n in node.next_nodes:
for i, pt in enumerate(n.prev_tensors):
for nt in node.next_tensors:
if id(nt) == id(pt):
idx = n.prev_indices[i]
if type(n.module) is TraceFunction:
prev_unique_name = tensor_name_from_parts(old_unique_name, idx, is_constant_node)
next_unique_name = tensor_name_from_parts(node.unique_name, idx, is_new_constant_node)
log.debug(f'node rename: {prev_unique_name} -> {next_unique_name}')
n.module.replace_tensor_name(prev_unique_name, next_unique_name)
n.module.update_args_string()
break
def fuse_nodes_to_func(
self, nodes: typing.List[TraceNode], full_name: str, kind: str, func_type: str, is_class: bool
) -> None:
"""Fuses several nodes into one function"""
if len(nodes) > 1:
# Set the full name if the first node is already a TraceFunction
# Otherwise, we need to construct one.
next_nodes = []
next_tensors = []
if type(nodes[0].module) is TraceFunction:
next_nodes.extend(nodes[-1].next_nodes)
next_tensors.extend(nodes[-1].next_tensors)
last_node_unique_name = nodes[-1].unique_name
first_node_unique_name = nodes[0].unique_name
for node in nodes[1:]:
name = node.unique_name
del self.nodes_map[name]
self.forward_nodes.remove(node)
if id(node.module) in module_constructor_lines:
if id(node.module) in module_constructor_traced:
module_constructor_traced.remove(id(node.module))
del module_constructor_lines[id(node.module)]
node = nodes[0]
node.next_nodes.clear()
node.next_tensors.clear()
node.next_nodes.extend(next_nodes)
node.next_tensors.extend(next_tensors)
node.module.func_type = func_type
node.module.is_class = is_class
node.module.kind = kind
node.module.full_name = full_name
# Update next tensors
for t in node.next_tensors:
self.tensor_pre_node_dict[id(t)] = self.tensor_pre_node_dict[id(t)].replace(
last_node_unique_name, first_node_unique_name
)
for n in node.next_nodes:
# Update next nodes
for i, pn in enumerate(n.prev_nodes):
if pn.unique_name == last_node_unique_name:
n.prev_nodes[i] = node
for i, pt in enumerate(n.prev_tensors):
for nt in node.next_tensors:
if id(pt) == id(nt):
idx = n.prev_indices[i]
# Rewrite func calls in next nodes
if type(n.module) is TraceFunction:
old_unique_name = tensor_name_from_parts(last_node_unique_name, idx, False)
new_unique_name = tensor_name_from_parts(first_node_unique_name, idx, False)
n.module.replace_tensor_name(old_unique_name, new_unique_name)
n.module.update_args_string()
break
else:
# TODO: Implement this codepath
log.error('Module fusion requires the first node to be a TraceFunction.')
raise NotImplementedError
else:
log.warning('Calling fuse with fewer than 2 nodes is a no-op.')
def remove_node(self, node: TraceNode) -> None:
"""Remove a node from the computation graph"""
if node not in self.forward_nodes:
log.error('Only forward nodes can be removed')
assert False
if len(node.prev_nodes) != 1:
log.error('You cannot remove a node with multiple input nodes')
assert False
if len(node.prev_tensors) != len(node.next_tensors):
log.error('You cannot remove a node whose number of input tensors differs from its number of output tensors')
assert False
for idx, (prev_tensor, next_tensor) in enumerate(zip(node.prev_tensors, node.next_tensors)):
if prev_tensor.shape != next_tensor.shape:
log.error(f'The shape of the input/output at index {idx} mismatches')
log.error(f'The shape of the input tensor is {prev_tensor.shape}')
log.error(f'The shape of the output tensor is {next_tensor.shape}')
assert False
prev_node = node.prev_nodes[0]
next_nodes = node.next_nodes
tensor_dict = dict(zip(node.next_tensors, node.prev_tensors))
index_dict = dict(zip(node.next_tensors, node.prev_indices))
is_constant_node = type(node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
is_prev_constant_node = type(prev_node.module) in (ConstantNode, torch.nn.quantized.FloatFunctional)
old_unique_name = node.unique_name
# Deal with previous nodes
if node in prev_node.next_nodes:
prev_node.next_nodes.remove(node)
else:
log.error('Current node is not in the next nodes of the previous node')
assert False
prev_node.next_nodes.extend(next_nodes)
# Deal with next nodes
for n in next_nodes:
# Handle previous nodes
for i, pn in enumerate(n.prev_nodes):
if pn == node:
n.prev_nodes[i] = prev_node
# Handle previous tensors
for i, pt in enumerate(n.prev_tensors):
if pt in tensor_dict:
n.prev_tensors[i] = tensor_dict[pt]
old_idx = n.prev_indices[i]
n.prev_indices[i] = index_dict[pt]
new_idx = n.prev_indices[i]
# Rewrite func calls in next nodes
if type(n.module) is TraceFunction:
if n.module.args_string is not None:
prev_unique_name = tensor_name_from_parts(old_unique_name, old_idx, is_constant_node)
new_unique_name = tensor_name_from_parts(
prev_node.unique_name, new_idx, is_prev_constant_node
)
n.module.replace_tensor_name(prev_unique_name, new_unique_name)
n.module.update_args_string()
# Remove this node
self.forward_nodes.remove(node)
del self.nodes_map[node.unique_name]
if id(node.module) in module_constructor_lines:
if id(node.module) in module_constructor_traced:
module_constructor_traced.remove(id(node.module))
del module_constructor_lines[id(node.module)]
def load_overridable_modules():
if overridable_modules_loaded():
modules = overridable_modules
else:
modules = fetch_modules()
overridable_modules.extend(modules)
overridable_modules_loaded(True)
return modules
def load_overridable_funcs():
if overridable_funcs_loaded():
funcs = overridable_funcs
else:
funcs = fetch_funcs()
config = os.path.join(current_dir, 'configs/torch_quant_stub_func.yml')
funcs.update(fetch_funcs(config))
overridable_funcs.update(funcs)
overridable_funcs_loaded(True)
return funcs
def load_torch_overrides_funcs(funcs):
if not torch_overrides_funcs_loaded():
o_funcs, o_wrappers = prepare_torch_overrides_funcs(funcs)
torch_overrides_funcs.extend(o_funcs)
torch_overrides_wrappers.extend(o_wrappers)
torch_overrides_funcs_loaded(True)
else:
o_funcs, o_wrappers = torch_overrides_funcs, torch_overrides_wrappers
return o_funcs, o_wrappers
def load_creation_funcs():
if overridable_creation_funcs_loaded():
creation_funcs = overridable_creation_funcs
else:
creation_funcs = fetch_funcs(os.path.join(current_dir, 'configs/torch_creation_funcs_override.yml'))
overridable_creation_funcs.update(creation_funcs)
overridable_creation_funcs_loaded(True)
return creation_funcs
def load_tracking_modules():
if tracking_modules_loaded():
tracking_modules = torch_tracking_modules
else:
tracking_modules = fetch_modules(os.path.join(current_dir, 'configs/torch_tracking_modules.yml'))
torch_tracking_modules.extend(tracking_modules)
tracking_modules_loaded(True)
return tracking_modules
@contextlib.contextmanager
def patch_helper(
wrap_modules: bool = True,
wrap_funcs: bool = True,
wrap_creation_funcs: bool = True,
wrap_tracking_modules: bool = False,
):
"""Temporarily monkeypatches the functions and the modules in PyTorch."""
if wrap_modules:
if not modules_overrided():
modules = load_overridable_modules()
modules_overrided(True)
else:
wrap_modules = False
if wrap_funcs:
if not funcs_overrided():
funcs = load_overridable_funcs()
o_funcs, o_wrappers = load_torch_overrides_funcs(funcs)
funcs_overrided(True)
else:
wrap_funcs = False
if wrap_creation_funcs:
if not creation_funcs_overrided():
creation_funcs = load_creation_funcs()
creation_funcs_overrided(True)
else:
wrap_creation_funcs = False
if wrap_tracking_modules:
if not tracking_modules_overrided():
tracking_modules = load_tracking_modules()
tracking_modules_overrided(True)
else:
wrap_tracking_modules = False
if wrap_modules:
module_manager = patch_modules(modules, ('__init__', '__setattr__'), (new_init_gen, new_setattr_gen))
module_manager.__enter__()
if wrap_funcs:
tracked_funcs = [funcs, {torch.Tensor: ['__getattribute__']}]
wrappers = [new_func_gen, new_getattr_gen]
tracked_funcs.extend(o_funcs)
wrappers.extend(o_wrappers)
func_manager = patch_funcs(tracked_funcs, wrappers)
func_manager.__enter__()
if wrap_creation_funcs:
creation_func_manager = patch_funcs(creation_funcs, new_creation_func_gen)
creation_func_manager.__enter__()
if wrap_tracking_modules:
tracking_module_manager = patch_modules(tracking_modules, '__new__', new_init_tracking_gen)
tracking_module_manager.__enter__()
yield True
if wrap_modules:
module_manager.__exit__(None, None, None)
modules_overrided(False)
if wrap_funcs:
func_manager.__exit__(None, None, None)
funcs_overrided(False)
if wrap_creation_funcs:
creation_func_manager.__exit__(None, None, None)
creation_funcs_overrided(False)
if wrap_tracking_modules:
tracking_module_manager.__exit__(None, None, None)
tracking_modules_overrided(False)
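# A usage sketch (illustrative): the patches are installed on entry and rolled
# back on exit, so the model should be executed inside the `with` block.
#
#   with patch_helper(wrap_modules=True, wrap_funcs=True, wrap_creation_funcs=False):
#       model(dummy_input)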
def check_types(values: typing.Iterable) -> typing.Optional[str]:
"""Returns the name of the first unsupported type found in the args, or None if every value is supported."""
for value in values:
if isinstance(value, (tuple, list)):
res = check_types(value)
if res is not None:
return res
elif not isinstance(value, (int, float, bool, str, type(None))):
return type(value).__name__
return None
def check_tensor_type(value) -> bool:
"""Check whether types are related to torch.Tensor."""
if isinstance(value, (tuple, list)):
for item in value:
res = check_tensor_type(item)
if res:
return res
elif type(value) is torch.Tensor:
return True
return False
def check_creation_args(args: typing.Iterable) -> typing.Tuple:
"""Cast arguments of type of Tensor to normal values"""
new_args = []
for arg in args:
if isinstance(arg, (tuple, list)):
new_args.append(check_creation_args(arg))
elif type(arg) is torch.Tensor:
if arg.dim() == 0:
new_args.append(arg.item())
else:
new_args.append(arg.tolist())
else:
new_args.append(arg)
return tuple(new_args)
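# Illustrative behavior of `check_creation_args`: zero-dim tensors become
# Python scalars and other tensors become nested lists.
#
#   check_creation_args((torch.tensor(3), torch.tensor([1, 2]), 'cpu'))
#   # -> (3, [1, 2], 'cpu')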
def tensor_name_from_parts(node_name, node_idx=None, is_constant_node=False):
ns = 'self.' if is_constant_node else ''
if node_idx is None:
return f'{ns}{node_name}'
else:
if isinstance(node_idx, (list, tuple)):
indices_str = ''.join([f'[{i}]' for i in node_idx])
return f'{ns}{node_name}{indices_str}'
else:
return f'{ns}{node_name}[{node_idx}]'
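# Illustrative results of `tensor_name_from_parts` (node names are hypothetical):
#
#   tensor_name_from_parts('conv_0')                 -> 'conv_0'
#   tensor_name_from_parts('lstm_0', 1)              -> 'lstm_0[1]'
#   tensor_name_from_parts('split_0', (1, 0))        -> 'split_0[1][0]'
#   tensor_name_from_parts('weight_0', None, True)   -> 'self.weight_0'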
def trace(
module: torch.nn.Module,
dummy_input: torch.Tensor,
eliminate_dead_graph: bool = False,
patch_torch_size: bool = False,
) -> TraceGraph:
"""main function for tracing"""
try:
with construct_trace_graph(module, dummy_input, eliminate_dead_graph, patch_torch_size) as new_graph:
new_graph.init()
return new_graph
except Exception:
traceback.print_exc()
if current_graph() is not None:
log.error(f'inputs: {[n.unique_name for n in current_graph().input_nodes]}')
log.error(f'forwards: {[n.unique_name for n in current_graph().forward_nodes]}')
log.error(f'outputs: {[n.unique_name for n in current_graph().output_nodes]}')
log.error(f'constants: {[n.unique_name for n in current_graph().constant_nodes]}')
quit()
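# A minimal end-to-end sketch (illustrative; `MyModel` and the output paths are
# hypothetical):
#
#   model = MyModel().eval()
#   dummy_input = torch.ones(1, 3, 224, 224)
#   graph = trace(model, dummy_input)
#   graph.generate_code('out/my_model.py', 'out/my_model.pth', 'MyModel')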