def default_from_interp_rep_transform()

in captum/attr/_core/lime.py [0:0]


def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
    r"""
    Map a sample from the interpretable representation back to input space.

    Each entry of ``curr_sample[0]`` acts as an on/off switch for one
    interpretable feature, and ``feature_mask`` assigns every input element
    to one of those entries. Elements whose entry is non-zero keep their
    value from ``original_inputs``; all other elements take the
    corresponding value from ``baselines``.

    Args:
        curr_sample: Sample in the interpretable space. Only its first row
            (``curr_sample[0]``) is read, so it presumably has shape
            ``(1, num_interp_features)`` -- confirm against callers.
        original_inputs: Tensor, or tuple of tensors, being explained.
        **kwargs: Must contain

            - ``feature_mask``: a Tensor, or a sequence of Tensors (one per
              input), used to index ``curr_sample[0]``; presumably integer
              (long) feature ids -- confirm against callers.
            - ``baselines``: replacement value(s) for "off" elements. When
              ``feature_mask`` is a sequence, ``baselines`` is indexed
              per input.

    Returns:
        A single tensor when ``feature_mask`` is a Tensor; otherwise a tuple
        with one tensor per entry of ``feature_mask``.

    Raises:
        AssertionError: if ``feature_mask`` or ``baselines`` is missing from
            ``kwargs``.
    """
    assert (
        "feature_mask" in kwargs
    ), "Must provide feature_mask to use default interpretable representation transform"
    assert (
        "baselines" in kwargs
    ), "Must provide baselines to use default interpretable representation transform"
    feature_mask = kwargs["feature_mask"]
    baselines = kwargs["baselines"]
    # Only the first row of the sample is used as the on/off vector.
    interp_flags = curr_sample[0]

    def _blend(mask, inputs, baseline):
        # Look up each element's feature flag. Then keep ``inputs`` where the
        # flag is set and ``baseline`` elsewhere. The blend is arithmetic, so
        # ``baseline`` may be a broadcastable tensor or a Python scalar.
        keep = interp_flags[mask].bool()
        return keep.to(inputs.dtype) * inputs + (~keep).to(inputs.dtype) * baseline

    if isinstance(feature_mask, Tensor):
        return _blend(feature_mask, original_inputs, baselines)
    # Multiple inputs: pair each mask with the input and baseline at the same
    # position.
    return tuple(
        _blend(mask, original_inputs[i], baselines[i])
        for i, mask in enumerate(feature_mask)
    )