def _select_example()

in captum/attr/_utils/batching.py [0:0]


def _select_example(curr_arg: Any, index: int, bsz: int) -> Any:
    r"""
    Select the ``index``-th example from each batched element of ``curr_arg``.

    ``curr_arg`` may be ``None``, a single object, or a tuple of objects.

    Args:
        curr_arg: ``None``, a single argument, or a tuple of arguments.
        index: Position of the example to select within the batch.
        bsz: Batch size used to decide which elements are batched.

    Returns:
        ``None`` if ``curr_arg`` is ``None``. Otherwise, the tuple of
        selected elements, passed to ``_format_output`` together with a flag
        saying whether ``curr_arg`` was originally a tuple.

    Batched elements:
        An element counts as batched when it is a non-scalar ``Tensor`` or a
        ``list`` whose length equals ``bsz``. Each batched element is sliced
        to ``[index : index + 1]``, which keeps a length-1 leading dimension.

    Unbatched elements:
        Every other element is returned unchanged. This includes tuples,
        other objects, lists/tensors of a different length, and 0-dimensional
        tensors, which have no batch dimension to index into.
    """
    if curr_arg is None:
        return None
    is_tuple = isinstance(curr_arg, tuple)
    # Normalize to a tuple so single arguments and tuples share one code path.
    args = curr_arg if is_tuple else (curr_arg,)

    def _is_batched(arg: Any) -> bool:
        # len() raises TypeError on a 0-d tensor, so scalar tensors must be
        # ruled out before the length comparison; they pass through as-is.
        if isinstance(arg, Tensor):
            return arg.dim() > 0 and len(arg) == bsz
        return isinstance(arg, list) and len(arg) == bsz

    selected_arg = tuple(
        arg[index : index + 1] if _is_batched(arg) else arg for arg in args
    )
    return _format_output(is_tuple, selected_arg)