in tinynn/graph/tracer.py [0:0]
def hook_modules(module):
    """Temporarily adds the hooks to a `nn.Module` for tracing.

    Registers forward pre-/post-hooks on ``module`` itself (recording the
    graph's input and output nodes) and on every submodule that
    ``register_submodule_tracer`` judges "related" (recording one forward node
    per call), then yields ``module``. Every hook registered here is removed
    when the generator finishes or is closed. This includes the case where an
    exception is thrown into it at the ``yield`` or raised while hooks are
    being installed, so a failed trace does not leave hooks attached to the
    caller's model.

    NOTE(review): this is a generator; presumably it is wrapped with
    ``contextlib.contextmanager`` where it is defined -- confirm.

    Args:
        module: the root ``nn.Module`` to trace.

    Yields:
        ``module`` itself, with the tracing hooks installed.
    """
    hooks = []

    def register_submodule_tracer(module):
        """Install tracing hooks on ``module`` if it is related and not yet traced."""

        def _submodule_pre_tracer(module, inputs):
            log.debug(f'pre tracer in _submodule_pre_tracer in {type(module).__name__}')
            # If the lock is already held, an enclosing hooked call is in progress;
            # mark this module so the post-hook below ignores this call.
            if lock():
                skip_modules.add(weakref.ref(module))
            lock(True)

        def _submodule_tracer(module, inputs, outputs):
            m_ref = weakref.ref(module)
            if m_ref in skip_modules:
                # Nested call: leave the lock held for the enclosing call.
                skip_modules.remove(m_ref)
                return None
            log.debug(f'tracer in _submodule_tracer in {type(module).__name__}')
            node = TraceNode(module)
            # ``noop_handler`` may substitute the outputs; ``None`` means "keep them".
            modified_outputs = noop_handler(node, inputs, outputs)
            add_forward_node(node, inputs, outputs if modified_outputs is None else modified_outputs)
            lock(False)
            # A forward hook returning ``None`` leaves the module's output unchanged.
            return modified_outputs

        module_unique_name = current_graph().module_unique_name_dict[id(module)]
        if module_unique_name in current_graph().traced_modules:
            log.debug(f"module {module_unique_name} is traced")
            return None

        if id(module) in module_constructor_traced:
            # id() values are only unique among live objects; presumably the
            # weakref lookup guards against a reused id of a collected module.
            # The ``type(None)`` fallback, when called, yields ``None``.
            related = (
                id(module) in module_constructor_lines
                and module_constructor_weakrefs.get(id(module), type(None))() is not None
            )
        else:
            # Exact-type fast path first, then subclasses of any overridable module.
            related = type(module) in overridable_modules or any(
                isinstance(module, m) for m in overridable_modules
            )

        if related:
            hooks.append(module.register_forward_pre_hook(_submodule_pre_tracer))
            hooks.append(module.register_forward_hook(_submodule_tracer))
            current_graph().related_modules.append(module_unique_name)
        # Recorded for every visited module so the check above skips it next time.
        current_graph().traced_modules.append(module_unique_name)
        return None

    unsupported_nested_msg = (
        "Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or"
        " tuple"
    )

    def _is_tensor_or_tensor_seq(value):
        """Return True for a tensor, or a list/tuple (incl. subclasses) made only of tensors."""
        return type(value) is torch.Tensor or (
            isinstance(value, (list, tuple)) and all(type(x) is torch.Tensor for x in value)
        )

    def _add_output(value):
        """Record ``value`` as one output node of the traced graph."""
        node = TraceNode(TraceFunction("output"))
        add_output_node(node, value)

    def _add_nested_outputs(values):
        """Record each supported element of ``values`` as an output; warn on the rest."""
        for value in values:
            if _is_tensor_or_tensor_seq(value):
                _add_output(value)
            else:
                log.warning(unsupported_nested_msg)

    def _model_pre_tracer(module, inputs):
        """Record one input node per positional input of the root module."""
        log.debug('pre tracer in _model_pre_tracer')
        for i in inputs:
            node = TraceNode(TraceFunction("input"))
            add_input_node(node, i)

    def _model_tracer(module, inputs, outputs):
        """Record the root module's outputs as output nodes."""
        log.debug('tracer in _model_tracer')
        if type(outputs) is torch.Tensor:
            _add_output(outputs)
        elif isinstance(outputs, (list, tuple)):
            _add_nested_outputs(outputs)
        elif isinstance(outputs, dict):
            _add_nested_outputs(outputs.values())
        else:
            log.warning(f'Output type is not supported: {type(outputs).__name__}, try to extract tensors from it')
            # Unlike the branches above, only exact ``list``/``tuple`` types are accepted here.
            for k in outputs.__dir__():
                v = getattr(outputs, k)
                if type(v) is torch.Tensor or (type(v) in (list, tuple) and all(type(x) is torch.Tensor for x in v)):
                    _add_output(v)

    try:
        log.debug('trace: apply register_submodule_tracer')
        module.apply(register_submodule_tracer)
        log.debug('trace: add hooks')
        hooks.append(module.register_forward_pre_hook(_model_pre_tracer))
        hooks.append(module.register_forward_hook(_model_tracer))
        yield module
    finally:
        # Runs on normal exit and on exceptions, so hooks never outlive the trace.
        for hook in hooks:
            hook.remove()