in captum/attr/_core/feature_ablation.py [0:0]
def _get_feature_counts(self, inputs, feature_mask, **kwargs):
    """Return the number of features in each input tensor.

    Args:
        inputs: Sequence of input tensors. Without a mask, every element of
            one slice ``inp[0]`` counts as its own feature (the first
            dimension is presumably the batch dimension — confirm with
            callers).
        feature_mask: ``None``/empty, or a sequence aligned with ``inputs``
            whose entries are integer mask tensors or ``None``. A ``None``
            entry falls back to the unmasked count for that input.
        **kwargs: Unused; accepted for interface compatibility.

    Returns:
        A tuple of per-input feature counts. When a mask sequence is given,
        one count is produced per ``zip(inputs, feature_mask)`` pair, so the
        result is as long as the shorter of the two. For a mask tensor the
        count is ``max(mask) - min(mask) + 1``: every id in that range is
        counted, even ids that do not actually occur in the mask. Empty
        inputs, and empty masks, contribute 0 features.
    """

    def _unmasked_count(inp):
        # Indexing ``inp[0]`` on a tensor with a zero-sized first dimension
        # raises IndexError, so empty inputs are short-circuited to 0.
        return inp[0].numel() if inp.numel() else 0

    def _masked_count(inp, mask):
        if mask is None:
            return _unmasked_count(inp)
        # An empty input has no features, matching the unmasked branch.
        # Also, max()/min() raise on an empty mask tensor.
        if inp.numel() == 0 or mask.numel() == 0:
            return 0
        return (mask.max() - mask.min()).item() + 1

    if not feature_mask:
        return tuple(_unmasked_count(inp) for inp in inputs)

    return tuple(
        _masked_count(inp, mask) for inp, mask in zip(inputs, feature_mask)
    )