in captum/metrics/_core/sensitivity.py [0:0]
def _next_sensitivity_max(current_n_perturb_samples: int) -> Tensor:
inputs_perturbed = _generate_perturbations(current_n_perturb_samples)
# copy kwargs and update some of the arguments that need to be expanded
nonlocal kwarg_expanded_for
nonlocal kwargs_copy
if (
kwarg_expanded_for is None
or kwarg_expanded_for != current_n_perturb_samples
):
kwarg_expanded_for = current_n_perturb_samples
kwargs_copy = deepcopy(kwargs)
_expand_and_update_additional_forward_args(
current_n_perturb_samples, kwargs_copy
)
_expand_and_update_target(current_n_perturb_samples, kwargs_copy)
if "baselines" in kwargs:
baselines = kwargs["baselines"]
baselines = _format_baseline(
baselines, cast(Tuple[Tensor, ...], inputs)
)
if (
isinstance(baselines[0], Tensor)
and baselines[0].shape == inputs[0].shape
):
_expand_and_update_baselines(
cast(Tuple[Tensor, ...], inputs),
current_n_perturb_samples,
kwargs_copy,
)
expl_perturbed_inputs = explanation_func(inputs_perturbed, **kwargs_copy)
# tuplize `expl_perturbed_inputs` in case it is not
expl_perturbed_inputs = _format_tensor_into_tuples(expl_perturbed_inputs)
expl_inputs_expanded = tuple(
expl_input.repeat_interleave(current_n_perturb_samples, dim=0)
for expl_input in expl_inputs
)
sensitivities = torch.cat(
[
(expl_input - expl_perturbed).view(expl_perturbed.size(0), -1)
for expl_perturbed, expl_input in zip(
expl_perturbed_inputs, expl_inputs_expanded
)
],
dim=1,
)
# compute the norm of original input explanations
expl_inputs_norm_expanded = torch.norm(
torch.cat(
[expl_input.view(expl_input.size(0), -1) for expl_input in expl_inputs],
dim=1,
),
p=norm_ord,
dim=1,
keepdim=True,
).repeat_interleave(current_n_perturb_samples, dim=0)
expl_inputs_norm_expanded = torch.where(
expl_inputs_norm_expanded == 0.0,
torch.tensor(
1.0,
device=expl_inputs_norm_expanded.device,
dtype=expl_inputs_norm_expanded.dtype,
),
expl_inputs_norm_expanded,
)
# compute the norm for each input noisy example
sensitivities_norm = (
torch.norm(sensitivities, p=norm_ord, dim=1, keepdim=True)
/ expl_inputs_norm_expanded
)
return max_values(sensitivities_norm.view(bsz, -1))