in captum/_utils/common.py [0:0]
def _select_targets(output: Tensor, target: TargetType) -> Tensor:
    """Select the entries of ``output`` designated by ``target``.

    The first dimension of ``output`` is treated as the example (batch)
    dimension.

    Args:
        output: Tensor to select from.
        target: Which entries to pick. Supported forms:

            * ``None`` -- ``output`` is returned unchanged.
            * ``int`` or ``tuple`` -- forwarded as-is to
              ``_verify_select_column`` (defined elsewhere in this module;
              presumably applies the same index to every example -- confirm
              against that helper).
            * single-element ``torch.Tensor`` holding an integer -- unwrapped
              to a Python ``int`` and handled like the ``int`` case.
            * 1D ``torch.Tensor`` with one element per example -- ``output``
              must be 2D; element ``i`` is the column picked from row ``i``.
            * ``list`` of ``int`` with one element per example -- same as the
              1D tensor case.
            * ``list`` of ``tuple`` with one element per example -- tuple
              ``i`` indexes the remaining dimensions of ``output[i]``.

            For lists, only the type of the first element is inspected.

    Returns:
        ``output`` itself when ``target`` is ``None``; the result of
        ``_verify_select_column`` for scalar targets; a
        ``(num_examples, 1)`` tensor for per-example integer targets; or a
        stack (along a new leading dimension) of the per-example selections
        for per-example tuple targets.

    Raises:
        AssertionError: If ``target`` has an unsupported type or shape, if a
            target list is empty or its length differs from the number of
            examples, or if per-example integer targets are used with an
            ``output`` that is not 2D.
    """
    if target is None:
        return output

    num_examples = output.shape[0]
    dims = len(output.shape)

    if isinstance(target, (int, tuple)):
        return _verify_select_column(output, target)

    if isinstance(target, torch.Tensor):
        if torch.numel(target) == 1:
            # Read the scalar a single time instead of calling ``.item()``
            # once for the type check and again for the value.
            scalar = target.item()
            if isinstance(scalar, int):
                return _verify_select_column(output, scalar)
        if len(target.shape) == 1 and torch.numel(target) == num_examples:
            # Explicit raise instead of ``assert`` so the check is still
            # performed when Python runs with optimizations (-O).
            if dims != 2:
                raise AssertionError(
                    "Output must be 2D to select tensor of targets."
                )
            return torch.gather(output, 1, target.reshape(num_examples, 1))
        raise AssertionError(
            "Tensor target dimension %r is not valid. %r"
            % (target.shape, output.shape)
        )

    if isinstance(target, list):
        if len(target) != num_examples:
            raise AssertionError("Target list length does not match output!")
        if not target:
            # An empty list gives no element to infer the target kind from;
            # fail with a clear message rather than an IndexError below.
            raise AssertionError("Target list must not be empty.")

        first_elem = target[0]
        if isinstance(first_elem, int):
            if dims != 2:
                raise AssertionError(
                    "Output must be 2D to select tensor of targets."
                )
            # Build the index on the same device as ``output`` for gather.
            index = torch.tensor(target, device=output.device)
            return torch.gather(output, 1, index.reshape(num_examples, 1))
        if isinstance(first_elem, tuple):
            # Prefix each tuple with its example index to pick one entry
            # (or sub-tensor) per example.
            return torch.stack(
                [
                    output[(i,) + cast(Tuple, targ_elem)]
                    for i, targ_elem in enumerate(target)
                ]
            )
        raise AssertionError("Target element type in list is not valid.")

    # ``(target,)`` keeps the formatting safe for any object type.
    raise AssertionError("Target type %r is not valid." % (target,))