def _call_impl()

in torchrec/modules/lazy_extension.py

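This method is, in essence, a copy of torch.nn.Module._call_impl with one change: the forward-pre-hook loop inspects each hook's signature and passes the call's keyword arguments to hooks that declare three parameters, while two-parameter hooks keep the stock behavior. The body relies on names imported at the top of the file: functools, inspect, torch, torch.utils.hooks (as hooks), and the _global_* hook registries from torch.nn.modules.module.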

    def _call_impl(self, *input, **kwargs):  # noqa: C901
        # pyre-ignore[16]
        forward_call = (
            self._slow_forward if torch._C._get_tracing_state() else self.forward
        )
        # If we don't have any hooks, we want to skip the rest of the logic in
        # this function, and just call forward.
        # pyre-ignore[16]
        if not (
            self._backward_hooks
            or self._forward_hooks
            or self._forward_pre_hooks
            or _global_backward_hooks
            or _global_forward_hooks
            or _global_forward_pre_hooks
        ):
            return forward_call(*input, **kwargs)
        # Do not call functions when jit is used
        full_backward_hooks, non_full_backward_hooks = [], []
        if self._backward_hooks or _global_backward_hooks:
            # pyre-ignore[16]
            full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks()
        if _global_forward_pre_hooks or self._forward_pre_hooks:
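            # This loop is the one change from the stock implementation:
            # a pre-hook declaring three parameters also receives the call's
            # keyword arguments; two-parameter hooks behave as usual.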
            # pyre-ignore[60]: Concatenation not yet support for multiple variadic
            #  tuples: `*torch.nn.modules.module._global_forward_pre_hooks.values(),
            #  *self._forward_pre_hooks.values()`.
            for hook in (
                *_global_forward_pre_hooks.values(),
                *self._forward_pre_hooks.values(),
            ):
                if len(inspect.signature(hook).parameters) == 3:
                    result = hook(self, input, kwargs)
                else:
                    result = hook(self, input)
                if result is not None:
                    if not isinstance(result, tuple):
                        result = (result,)
                    input = result

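        # Full backward hooks go through a BackwardHook helper that wraps the
        # positional inputs now and the output after the forward call, so the
        # hooks fire with complete gradients.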
        bw_hook = None
        if full_backward_hooks:
            bw_hook = hooks.BackwardHook(self, full_backward_hooks)
            input = bw_hook.setup_input_hook(input)

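        # Run the actual forward, then let forward hooks observe or replace
        # the result.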
        result = forward_call(*input, **kwargs)
        if _global_forward_hooks or self._forward_hooks:
            # pyre-ignore[60]: Concatenation not yet support for multiple variadic
            #  tuples: `*torch.nn.modules.module._global_forward_hooks.values(),
            #  *self._forward_hooks.values()`.
            for hook in (
                *_global_forward_hooks.values(),
                *self._forward_hooks.values(),
            ):
                hook_result = hook(self, input, result)
                if hook_result is not None:
                    result = hook_result

        if bw_hook:
            result = bw_hook.setup_output_hook(result)

        # Handle the non-full backward hooks
        if non_full_backward_hooks:
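            # Legacy (non-full) backward hooks are registered on the grad_fn
            # of the first Tensor found in the output, descending into dicts
            # and sequences as needed.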
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values() if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            # pyre-ignore[16]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in non_full_backward_hooks:
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
                # pyre-ignore[16]
                self._maybe_warn_non_full_backward_hook(input, result, grad_fn)

        return result
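
The forward-pre-hook loop near the top is the behavioral payoff: a pre-hook may declare three parameters and receive the call's keyword arguments. A minimal sketch of that dispatch rule in isolation (the hook names and the dispatch helper below are illustrative, not part of the source):

    import inspect

    import torch

    def legacy_hook(module, args):
        # Two parameters: invoked exactly as stock Module._call_impl would.
        print(f"legacy: args={args}")

    def kwargs_hook(module, args, kwargs):
        # Three parameters: the _call_impl above also forwards kwargs.
        print(f"kwargs-aware: args={args}, kwargs={kwargs}")

    def dispatch(hook, module, args, kwargs):
        # Mirrors the signature check in the forward-pre-hook loop above.
        if len(inspect.signature(hook).parameters) == 3:
            return hook(module, args, kwargs)
        return hook(module, args)

    m = torch.nn.Identity()
    dispatch(legacy_hook, m, (torch.ones(2),), {"scale": 3.0})
    dispatch(kwargs_hook, m, (torch.ones(2),), {"scale": 3.0})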