in captum/attr/_utils/summarizer.py [0:0]
def update(self, x: Union[float, Tensor, Tuple[Union[float, Tensor], ...]]):
    r"""
    Calls `update` on each `Stat` object within the summarizer

    Each element of the (possibly single-element) input is routed to its own
    per-input `SummarizerSingleTensor`. That summarizer is created lazily the
    first time a given input position is seen, with its own copy of the stats.

    Args:
        x (float, Tensor, or Tuple[float or Tensor, ...]):
            The input(s) you wish to summarize. All calls on the same
            summarizer must consistently pass either a single input or a
            tuple of inputs. Non-tensor entries are converted with
            ``torch.tensor(..., dtype=torch.float)`` before being summarized.

    Raises:
        AssertionError: If `x` is a tuple while earlier calls passed a single
            input, or vice versa.
    """
    is_tuple = isinstance(x, tuple)
    if self._is_inputs_tuple is None:
        # First call: remember whether inputs arrive as a tuple or not.
        self._is_inputs_tuple = is_tuple
    elif self._is_inputs_tuple != is_tuple:
        # We want input to be consistently a single input or a tuple.
        # An explicit raise (instead of a bare `assert`) keeps this check
        # active under `python -O`. AssertionError is kept so callers see
        # the same exception type as before.
        raise AssertionError(
            "Summarizer inputs must consistently be either a single input or "
            f"a tuple of inputs: expected tuple={self._is_inputs_tuple}, "
            f"got tuple={is_tuple}"
        )

    # Kept as a function-local import, as in the original code
    # (NOTE(review): presumably to avoid a circular import — confirm).
    from captum._utils.common import _format_float_or_tensor_into_tuples

    x = _format_float_or_tensor_into_tuples(x)
    for i, inp in enumerate(x):
        if i >= len(self._summarizers):
            # _summarizers[i] is a new SummarizerSingleTensor, which
            # aims to summarize input i (i.e. x[i])
            #
            # Thus, we must copy our stats, as otherwise
            # in the best case the statistics for each input will be mangled
            # and in the worst case we will run into an error due to different
            # dimensionality in the input tensors (i.e.
            # x[i].shape != x[j].shape for some pair i, j)
            stats = self._copy_stats()
            self._summarizers.append(
                SummarizerSingleTensor(
                    # attribute name (incl. spelling) is defined elsewhere
                    stats=stats, summary_stats_indices=self._summary_stats_indicies
                )
            )
        if not isinstance(inp, torch.Tensor):
            # e.g. plain floats are promoted to float tensors
            inp = torch.tensor(inp, dtype=torch.float)
        self._summarizers[i].update(inp)