def _check_valid_module()

in captum/attr/_core/deep_lift.py

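This helper checks that the autograd graph around a module lines up the way DeepLift's hooks expect: the grad_fn of the module's output should point directly back to the grad_fn of its input, or to a clone of it. It returns True when that is the case (or when the input is a network input with no grad_fn) and False otherwise.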

def _check_valid_module(inputs_grad_fn, outputs) -> bool:
    def is_output_cloned(output_fn, input_grad_fn) -> bool:
        """
        Checks whether the output has been cloned. This happens especially
        in the case of LayerDeepLift.
        """
        # A clone inserts an extra CloneBackward node into the autograd
        # graph, so the input's grad_fn sits one hop further back.
        return (
            output_fn[0].next_functions is not None
            and output_fn[0].next_functions[0][0] == input_grad_fn
        )

    # `outputs.grad_fn` is the autograd node that produced the module's
    # output; its `next_functions` point to the nodes of the module's inputs.
    curr_fn = outputs.grad_fn
    first_next = curr_fn.next_functions[0]
    try:
        # If `inputs` is the input to the network, then its grad_fn is None
        # and no backward hook is computed for that input. That is why we
        # also treat a None `inputs_grad_fn` as valid.
        return (
            inputs_grad_fn is None
            or first_next[0] == inputs_grad_fn
            or is_output_cloned(first_next, inputs_grad_fn)
        )
    except IndexError:
        return False
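
The check can also be exercised outside of captum's hooks. Below is a minimal sketch (the driver code, module choice, and tensor shapes are assumptions, not part of captum; it assumes `_check_valid_module` from above is in scope) that calls the helper directly on a tiny graph and hits all three accepting branches:

import torch
import torch.nn as nn

linear = nn.Linear(3, 3)
relu = nn.ReLU()

x = torch.randn(2, 3, requires_grad=True)
h = linear(x)

# `x` is a leaf tensor, so its grad_fn is None: the None branch accepts.
print(_check_valid_module(x.grad_fn, h))           # True

# The ReLU consumes `h` directly, so the output's grad_fn chains straight
# back to `h.grad_fn` (AddmmBackward0): the direct-match branch accepts.
out = relu(h)
print(_check_valid_module(h.grad_fn, out))         # True

# Cloning inserts a CloneBackward0 node between `h` and the ReLU, so the
# `is_output_cloned` branch is needed to accept this graph.
out_cloned = relu(h.clone())
print(_check_valid_module(h.grad_fn, out_cloned))  # True

Inside captum the comparison is made against grad_fns captured while DeepLift's forward hooks are in place; the sketch above only reproduces the graph check itself.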