def infidelity_perturb_func_decorator()

in captum/metrics/_core/infidelity.py [0:0]


def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callable:
    r"""An auxiliary decorator that helps with computing perturbations given
    perturbed inputs. It is useful when `pertub_func` returns only perturbed
    inputs: the perturbations are then computed internally as
    ``(input - perturbed_input) / (input - baseline)`` if
    `multipy_by_inputs` is True (``(input - perturbed_input) / input`` when
    no baselines are passed), and as ``(input - perturbed_input)`` otherwise.
    The divisions go through ``safe_div`` with ``default_denom=1.0``
    (presumably substituted for zero denominators).

    The decorator may be used with arguments, e.g.
    ``@infidelity_perturb_func_decorator(True)``, or bare, as
    ``@infidelity_perturb_func_decorator``, in which case
    `multipy_by_inputs` takes its default value of True. In both cases the
    decorated `pertub_func` needs to only return perturbed inputs.

    Args:

        multipy_by_inputs (bool): Indicates whether model inputs'
                multiplier is factored in the computation of
                attribution scores. Default: True

    Returns:

        Callable: A decorator to apply to `pertub_func`. When the decorator
        is used bare, the decorated perturbation function is returned
        directly.
    """
    if callable(multipy_by_inputs):
        # Bare usage (``@infidelity_perturb_func_decorator`` without
        # parentheses) passes the perturbation function itself as the first
        # positional argument; apply the default decorator to it.
        return infidelity_perturb_func_decorator()(multipy_by_inputs)

    def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
        r"""
        Args:

            pertub_func(callable): Input perturbation function that takes inputs
                and optionally baselines and returns perturbed inputs

        Returns:

            default_perturb_func(callable): Internal default perturbation
            function that computes the perturbations internally and returns
            perturbations and perturbed inputs.

        Examples::
            >>> @infidelity_perturb_func_decorator(True)
            ... def perturb_fn(inputs):
            ...    noise = torch.tensor(np.random.normal(0, 0.003,
            ...                         inputs.shape)).float()
            ...    return inputs - noise
            >>> # Computes infidelity score using `perturb_fn`
            >>> infid = infidelity(model, perturb_fn, input, ...)

        """

        def default_perturb_func(
            inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None
        ):
            r"""Perturb `inputs` with `pertub_func` and derive the perturbations.

            Args:

                inputs: A tensor or a tuple of tensors to perturb.
                baselines: Optional baselines. They are forwarded to
                    `pertub_func` only when not None, and are used as the
                    reference point of the perturbations when
                    `multipy_by_inputs` is True.

            Returns:

                A pair ``(perturbations, inputs_perturbed)``. ``perturbations``
                is a tuple with one entry per input; ``inputs_perturbed`` is
                the output of `pertub_func` formatted into a tuple.

            Raises:

                ValueError: If `pertub_func` returns a different number of
                    perturbed tensors than there are inputs.
            """
            # Only forward baselines when they are given, so perturbation
            # functions that accept just ``inputs`` keep working.
            inputs_perturbed = (
                pertub_func(inputs, baselines)
                if baselines is not None
                else pertub_func(inputs)
            )
            inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)
            inputs = _format_tensor_into_tuples(inputs)
            # ``zip`` below would silently drop unmatched tensors and yield
            # wrong perturbations, so reject a count mismatch explicitly.
            if len(inputs_perturbed) != len(inputs):
                raise ValueError(
                    "`pertub_func` must return one perturbed tensor per input: "
                    f"got {len(inputs_perturbed)} perturbed tensor(s) for "
                    f"{len(inputs)} input(s)."
                )
            if baselines is not None:
                baselines = _format_baseline(baselines, inputs)

            if not multipy_by_inputs:
                perturbations = tuple(
                    inp - inp_pert for inp, inp_pert in zip(inputs, inputs_perturbed)
                )
            elif baselines is None:
                # Without baselines the reference point is zero, so the
                # denominator is the input itself.
                perturbations = tuple(
                    safe_div(inp - inp_pert, inp, default_denom=1.0)
                    for inp, inp_pert in zip(inputs, inputs_perturbed)
                )
            else:
                perturbations = tuple(
                    safe_div(inp - inp_pert, inp - base, default_denom=1.0)
                    for inp, inp_pert, base in zip(
                        inputs, inputs_perturbed, baselines
                    )
                )
            return perturbations, inputs_perturbed

        return default_perturb_func

    return sub_infidelity_perturb_func_decorator