in captum/attr/_utils/batching.py [0:0]
def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
    """Select a single example from every batched element of ``curr_arg``.

    ``curr_arg`` may be ``None``, a single value, or a tuple of values. An
    element is treated as batched when it is a ``Tensor`` or a ``list``
    whose length equals ``bsz``; such elements are sliced down to the one
    example at ``index`` (the slice keeps the leading dimension, so the
    result has length 1). All other elements are passed through unchanged.

    Args:
        curr_arg: ``None``, a single argument, or a tuple of arguments.
        index: Position of the example to select. ``-1`` selects the last
            example.
        bsz: Batch size used to decide whether an element is batched.

    Returns:
        ``None`` if ``curr_arg`` is ``None``; otherwise the result of
        ``_format_output(is_tuple, selected)`` where ``selected`` is the
        tuple of selected/pass-through elements and ``is_tuple`` records
        whether ``curr_arg`` was originally a tuple.
        NOTE(review): ``_format_output`` is defined elsewhere; presumably it
        unwraps the 1-tuple for non-tuple inputs - confirm against its
        definition.
    """
    if curr_arg is None:
        return None
    is_tuple = isinstance(curr_arg, tuple)
    if not is_tuple:
        curr_arg = (curr_arg,)

    # ``index + 1`` is 0 when ``index == -1``, and ``x[-1:0]`` is always
    # empty; an open-ended stop keeps the last example instead.
    stop = None if index == -1 else index + 1

    def _is_batched(arg: Any) -> bool:
        # True when ``arg`` is a Tensor/list whose length matches ``bsz``.
        if isinstance(arg, Tensor):
            # len() raises TypeError on 0-dim tensors, so check rank first
            # and let scalar tensors pass through untouched.
            return arg.dim() > 0 and len(arg) == bsz
        return isinstance(arg, list) and len(arg) == bsz

    selected_arg = tuple(
        arg[index:stop] if _is_batched(arg) else arg for arg in curr_arg
    )
    return _format_output(is_tuple, selected_arg)