in captum/attr/_core/deep_lift.py [0:0]
def _check_valid_module(inputs_grad_fn, outputs) -> bool:
def is_output_cloned(output_fn, input_grad_fn) -> bool:
"""
Checks if the output has been cloned. This happens especially in case of
layer deeplift.
"""
return (
output_fn[0].next_functions is not None
and output_fn[0].next_functions[0][0] == input_grad_fn
)
curr_fn = outputs.grad_fn
first_next = curr_fn.next_functions[0]
try:
# if `inputs` in the input to the network then the grad_fn is None and
# for that input backward_hook isn't computed. That's the reason why we
# need to check on `inputs_grad_fns[first_next[1]]` being None.
return (
inputs_grad_fn is None
or first_next[0] == inputs_grad_fn
or is_output_cloned(first_next, inputs_grad_fn)
)
except IndexError:
return False