# construct_feature_mask()
#
# Extracted from captum/attr/_core/lime.py

def construct_feature_mask(feature_mask, formatted_inputs):
    """Normalize a user-supplied feature mask (or build a default one).

    When ``feature_mask`` is ``None``, delegates entirely to
    ``_construct_default_feature_mask`` for ``formatted_inputs``.
    Otherwise the mask is formatted into a tuple of tensors, shifted so its
    smallest index is 0 (with a warning if a shift was needed), and the
    number of interpretable features is taken as ``max index + 1``.

    Returns:
        (feature_mask, num_interp_features) — the normalized tuple of mask
        tensors and the count of distinct interpretable features.
    """
    if feature_mask is None:
        # Default path: helper returns the (mask, count) pair directly.
        return _construct_default_feature_mask(formatted_inputs)

    feature_mask = _format_tensor_into_tuples(feature_mask)

    # Only non-empty tensors participate in the min/max index computation.
    lowest_index = int(
        min(
            torch.min(mask_tensor).item()
            for mask_tensor in feature_mask
            if mask_tensor.numel()
        )
    )
    if lowest_index != 0:
        warnings.warn(
            "Minimum element in feature mask is not 0, shifting indices to"
            " start at 0."
        )
        # Rebase every mask so indices begin at zero.
        feature_mask = tuple(
            mask_tensor - lowest_index for mask_tensor in feature_mask
        )

    # Count of interpretable features = highest (rebased) index + 1.
    num_interp_features = int(
        max(
            torch.max(mask_tensor).item()
            for mask_tensor in feature_mask
            if mask_tensor.numel()
        )
        + 1
    )
    return feature_mask, num_interp_features