in captum/metrics/_core/infidelity.py [0:0]
def infidelity_perturb_func_decorator(multipy_by_inputs: bool = True) -> Callable:
    r"""An auxiliary decorator factory that helps compute perturbations
    from perturbed inputs.

    It is useful when a perturbation function (`pertub_func`) returns only
    perturbed inputs. The returned decorator then computes the
    perturbations internally as

        (input - perturbed_input) / (input - baseline)

    if `multipy_by_inputs` is True, and as

        (input - perturbed_input)

    otherwise. If `baselines` is None after `_format_baseline` is applied,
    the denominator in the first formula is `input` alone. The division
    goes through `safe_div(..., default_denom=1.0)`. Presumably this
    substitutes 1.0 for zero denominators; confirm against `safe_div`.

    This is a decorator *factory*, so it must be called before it is used
    as a decorator, e.g. `@infidelity_perturb_func_decorator(True)`. A
    `pertub_func` decorated this way only needs to return perturbed inputs.

    Args:
        multipy_by_inputs (bool, optional): Indicates whether model inputs'
            multiplier is factored in the computation of attribution
            scores. The parameter name keeps its historical spelling for
            backward compatibility.
            Default: True

    Returns:
        sub_infidelity_perturb_func_decorator (callable): A decorator that
        turns a perturbed-inputs-only function into one that returns
        ``(perturbations, perturbed_inputs)``.
    """

    def sub_infidelity_perturb_func_decorator(pertub_func: Callable) -> Callable:
        r"""
        Args:
            pertub_func (callable): Input perturbation function that takes
                inputs, and optionally baselines, and returns perturbed
                inputs. It is called as ``pertub_func(inputs, baselines)``
                when baselines are provided (not None), and as
                ``pertub_func(inputs)`` otherwise.

        Returns:
            default_perturb_func (callable): Internal perturbation function
            that computes the perturbations internally and returns
            perturbations and perturbed inputs.

        Examples::
            >>> @infidelity_perturb_func_decorator(True)
            ... def perturb_fn(inputs):
            ...     noise = torch.tensor(np.random.normal(0, 0.003,
            ...                          inputs.shape)).float()
            ...     return inputs - noise
            >>> # Computes infidelity score using `perturb_fn`
            >>> infid = infidelity(model, perturb_fn, input, ...)
        """

        def default_perturb_func(
            inputs: TensorOrTupleOfTensorsGeneric, baselines: BaselineType = None
        ):
            r"""Call `pertub_func` and derive perturbations from its output.

            Args:
                inputs: Model inputs, either a tensor or a tuple of tensors.
                baselines (optional): Baselines forwarded to `pertub_func`
                    when not None. When `multipy_by_inputs` is True, they
                    are subtracted from the inputs to form the denominator.
                    Default: None

            Returns:
                tuple: ``(perturbations, inputs_perturbed)``.
                `perturbations` is a tuple with one entry per zipped
                (input, perturbed input) pair. `inputs_perturbed` is the
                output of `pertub_func` passed through
                `_format_tensor_into_tuples`.
            """
            # Only forward baselines when the caller supplied them, so that
            # single-argument perturbation functions keep working.
            inputs_perturbed = (
                pertub_func(inputs, baselines)
                if baselines is not None
                else pertub_func(inputs)
            )
            inputs_perturbed = _format_tensor_into_tuples(inputs_perturbed)
            inputs = _format_tensor_into_tuples(inputs)
            baselines = _format_baseline(baselines, inputs)
            if baselines is None:
                # No baseline: normalize by the input itself (when enabled).
                perturbations = tuple(
                    safe_div(
                        inp - inp_perturbed,
                        inp,
                        default_denom=1.0,
                    )
                    if multipy_by_inputs
                    else inp - inp_perturbed
                    for inp, inp_perturbed in zip(inputs, inputs_perturbed)
                )
            else:
                # `inp - baseline` is only evaluated when normalization is
                # enabled, because the conditional expression is lazy.
                perturbations = tuple(
                    safe_div(
                        inp - inp_perturbed,
                        inp - baseline,
                        default_denom=1.0,
                    )
                    if multipy_by_inputs
                    else inp - inp_perturbed
                    for inp, inp_perturbed, baseline in zip(
                        inputs, inputs_perturbed, baselines
                    )
                )
            return perturbations, inputs_perturbed

        return default_perturb_func

    return sub_infidelity_perturb_func_decorator