def _get_feature_counts()

in captum/attr/_core/feature_ablation.py [0:0]


    def _get_feature_counts(self, inputs, feature_mask, **kwargs):
        """return the numbers of input features"""
        if not feature_mask:
            return tuple(inp[0].numel() if inp.numel() else 0 for inp in inputs)

        return tuple(
            (mask.max() - mask.min()).item() + 1
            if mask is not None
            else (inp[0].numel() if inp.numel() else 0)
            for inp, mask in zip(inputs, feature_mask)
        )