def hook_modules()

in tinynn/graph/tracer.py [0:0]


def hook_modules(module):
    """Temporarily adds the hooks to a `nn.Module` for tracing.

    Forward pre-/post-hooks are registered on every submodule of `module`
    that is considered "related" (see `register_submodule_tracer`) and on
    `module` itself. The module is then yielded, and every registered hook is
    removed afterwards.

    Hook removal happens in a `finally` clause. The hooks are detached even
    when registration fails part-way through, and even when the code
    consuming the yielded module raises (e.g. a failing forward pass).

    Args:
        module: The root `nn.Module` to trace.

    Yields:
        The same `module`, with the tracing hooks attached.
    """
    hooks = []

    def _is_tensor_or_tensor_seq(obj):
        """Returns whether `obj` is exactly a `torch.Tensor`, or a list/tuple of them."""
        return type(obj) is torch.Tensor or (
            isinstance(obj, (list, tuple)) and all(type(x) is torch.Tensor for x in obj)
        )

    def _add_output(value):
        """Records `value` as a graph output node."""
        node = TraceNode(TraceFunction("output"))
        add_output_node(node, value)

    def register_submodule_tracer(module):
        """Attaches the submodule pre/post tracing hooks to `module` if it is related."""

        def _submodule_pre_tracer(module, input):
            log.debug(f'pre tracer in _submodule_pre_tracer in {type(module).__name__}')
            # If the lock is already held we are nested inside another traced
            # call, so this module's own post-hook must not emit a node.
            if lock():
                skip_modules.add(weakref.ref(module))

            lock(True)

        def _submodule_tracer(module, inputs, outputs):
            m_ref = weakref.ref(module)
            if m_ref in skip_modules:
                # Nested call: the outer holder of the lock is responsible for
                # releasing it, so return without touching the lock.
                skip_modules.remove(m_ref)
                return None

            log.debug(f'tracer in _submodule_tracer in {type(module).__name__}')
            node = TraceNode(module)
            modified_outputs = noop_handler(node, inputs, outputs)
            add_forward_node(node, inputs, outputs if modified_outputs is None else modified_outputs)
            lock(False)
            return modified_outputs

        graph = current_graph()
        module_unique_name = graph.module_unique_name_dict[id(module)]
        if module_unique_name in graph.traced_modules:
            log.debug(f"module {module_unique_name} is traced")
            return None

        if id(module) in module_constructor_traced:
            # Related only if its constructor line was recorded and the weak
            # reference to it is still alive (`type(None)()` yields None).
            related = (
                id(module) in module_constructor_lines
                and module_constructor_weakrefs.get(id(module), type(None))() is not None
            )
        else:
            related = type(module) in overridable_modules or any(
                isinstance(module, m) for m in overridable_modules
            )

        if related:
            hooks.append(module.register_forward_pre_hook(_submodule_pre_tracer))
            hooks.append(module.register_forward_hook(_submodule_tracer))
            graph.related_modules.append(module_unique_name)

        graph.traced_modules.append(module_unique_name)
        return None

    def _model_pre_tracer(module, inputs):
        """Records each positional input of the root module as a graph input node."""
        log.debug('pre tracer in _model_pre_tracer')
        for i in inputs:
            node = TraceNode(TraceFunction("input"))
            add_input_node(node, i)

    def _model_tracer(module, inputs, outputs):
        """Records the root module's outputs as graph output nodes."""
        log.debug('tracer in _model_tracer')
        if type(outputs) is torch.Tensor:
            _add_output(outputs)
        elif isinstance(outputs, (list, tuple, dict)):
            values = outputs.values() if isinstance(outputs, dict) else outputs
            for v in values:
                if _is_tensor_or_tensor_seq(v):
                    _add_output(v)
                else:
                    log.warning(
                        "Only tensors or list, tuple of tensors are supported when nested in a class, dict, list or"
                        " tuple"
                    )
        else:
            # Best-effort fallback: scan the object's attributes for tensors.
            # NOTE(review): `getattr` may raise for properties with side effects;
            # exact `list`/`tuple` types are required here (no subclasses).
            log.warning(f'Output type is not supported: {type(outputs).__name__}, try to extract tensors from it')
            for k in outputs.__dir__():
                v = getattr(outputs, k)
                if type(v) is torch.Tensor or (type(v) in (list, tuple) and all(type(x) is torch.Tensor for x in v)):
                    _add_output(v)

    log.debug('trace: apply register_submodule_tracer')
    try:
        module.apply(register_submodule_tracer)

        log.debug('trace: add hooks')
        hooks.append(module.register_forward_pre_hook(_model_pre_tracer))
        hooks.append(module.register_forward_hook(_model_tracer))

        yield module
    finally:
        # Always detach, even on a failed registration or a raising body.
        for hook in hooks:
            hook.remove()