in torchrec/modules/lazy_extension.py [0:0]
def _call_impl(self, *input, **kwargs):  # noqa: C901
    """Invoke the module's forward function together with its hooks.

    The forward callable is ``self._slow_forward`` while
    ``torch._C._get_tracing_state()`` is truthy and ``self.forward``
    otherwise. When neither the module nor the global registries hold any
    hook, the forward callable is invoked directly and its value returned.

    Otherwise the steps run in this order:

    1. Forward pre-hooks (global ones first, then the module's). A hook
       whose signature has exactly three parameters is also handed
       ``kwargs``. A non-``None`` return value replaces the positional
       inputs; a non-tuple value is wrapped in a one-element tuple.
    2. If full backward hooks exist, a ``hooks.BackwardHook`` is installed
       on the inputs.
    3. The forward callable runs.
    4. Forward hooks (global first, then the module's) are called as
       ``hook(self, input, result)``; a non-``None`` return value replaces
       the result.
    5. The ``BackwardHook``, if any, is installed on the result.
    6. Non-full backward hooks are registered on the ``grad_fn`` of the
       first tensor found in the result, when that ``grad_fn`` is not
       ``None``.

    Args:
        *input: positional arguments for the forward callable.
        **kwargs: keyword arguments for the forward callable.

    Returns:
        The forward result, possibly replaced by forward hooks and/or
        wrapped by the full backward hook machinery.
    """
    if torch._C._get_tracing_state():
        # pyre-ignore[16]
        forward_call = self._slow_forward
    else:
        forward_call = self.forward

    # Fast path: nothing is registered anywhere, so go straight to forward.
    # pyre-ignore[16]
    no_hooks = (
        not self._backward_hooks
        and not self._forward_hooks
        and not self._forward_pre_hooks
        and not _global_backward_hooks
        and not _global_forward_hooks
        and not _global_forward_pre_hooks
    )
    if no_hooks:
        return forward_call(*input, **kwargs)

    if self._backward_hooks or _global_backward_hooks:
        # pyre-ignore[16]
        full_bw_hooks, plain_bw_hooks = self._get_backward_hooks()
    else:
        full_bw_hooks, plain_bw_hooks = [], []

    # Positional inputs as they evolve through the pre-hooks.
    args = input
    if _global_forward_pre_hooks or self._forward_pre_hooks:
        # Snapshot both registries up front: global hooks run first.
        pre_hooks = tuple(_global_forward_pre_hooks.values()) + tuple(
            self._forward_pre_hooks.values()
        )
        for pre_hook in pre_hooks:
            # A three-parameter hook additionally receives the kwargs.
            if len(inspect.signature(pre_hook).parameters) == 3:
                replacement = pre_hook(self, args, kwargs)
            else:
                replacement = pre_hook(self, args)
            if replacement is None:
                continue
            if isinstance(replacement, tuple):
                args = replacement
            else:
                args = (replacement,)

    bw_hook = None
    if full_bw_hooks:
        bw_hook = hooks.BackwardHook(self, full_bw_hooks)
        args = bw_hook.setup_input_hook(args)

    output = forward_call(*args, **kwargs)

    if _global_forward_hooks or self._forward_hooks:
        fwd_hooks = tuple(_global_forward_hooks.values()) + tuple(
            self._forward_hooks.values()
        )
        for fwd_hook in fwd_hooks:
            override = fwd_hook(self, args, output)
            if override is not None:
                output = override

    if bw_hook:
        output = bw_hook.setup_output_hook(output)

    # Everything below only concerns the non-full backward hooks.
    if not plain_bw_hooks:
        return output

    # Dig into the output until a tensor is reached: dicts yield their first
    # tensor value, anything else is indexed at position 0.
    probe = output
    while not isinstance(probe, torch.Tensor):
        if isinstance(probe, dict):
            probe = next(v for v in probe.values() if isinstance(v, torch.Tensor))
        else:
            probe = probe[0]

    # pyre-ignore[16]
    grad_fn = probe.grad_fn
    if grad_fn is None:
        return output

    for bw in plain_bw_hooks:
        # Bind the module as first argument while keeping the hook's metadata.
        grad_fn.register_hook(
            functools.update_wrapper(functools.partial(bw, self), bw)
        )
    # pyre-ignore[16]
    self._maybe_warn_non_full_backward_hook(args, output, grad_fn)
    return output