in captum/attr/_core/lime.py [0:0]
def default_from_interp_rep_transform(curr_sample, original_inputs, **kwargs):
    """Map an interpretable-space sample back into the original input space.

    Each entry of ``feature_mask`` is used as an index into the first row of
    ``curr_sample``. Where the looked-up value is non-zero the corresponding
    element of ``original_inputs`` is kept; where it is zero the element is
    replaced by the corresponding baseline value.

    Args:
        curr_sample: Indexable whose element ``0`` is a tensor that can be
            indexed by the feature mask(s).
            NOTE(review): presumably shape (1, num_interp_features) - confirm
            against the caller.
        original_inputs: A single tensor when ``feature_mask`` is a Tensor,
            otherwise a sequence of tensors aligned with ``feature_mask``.
        **kwargs: Must contain:
            feature_mask: An index tensor, or a sequence of index tensors
                (one per input).
            baselines: Value(s) multiplied into the masked-out positions.
                In the sequence case it is indexed per input
                (``baselines[j]``).

    Returns:
        A tensor (single ``feature_mask`` tensor) or a tuple of tensors (one
        per entry of ``feature_mask``) mixing original inputs and baselines.

    Raises:
        AssertionError: If ``feature_mask`` or ``baselines`` is missing from
            ``kwargs``.
    """
    # Raised explicitly instead of via ``assert`` so the validation is not
    # stripped under ``python -O``; AssertionError is kept so existing callers
    # that catch it keep working.
    if "feature_mask" not in kwargs:
        raise AssertionError(
            "Must provide feature_mask to use default interpretable "
            "representation transform"
        )
    if "baselines" not in kwargs:
        raise AssertionError(
            "Must provide baselines to use default interpretable "
            "representation transform"
        )
    feature_mask = kwargs["feature_mask"]
    baselines = kwargs["baselines"]

    def _blend(inp, mask, baseline):
        # Look up, for every input element, whether its interpretable feature
        # is switched on in the sample; cast to the input dtype so the
        # multiplication preserves that dtype.
        keep = curr_sample[0][mask].bool()
        return keep.to(inp.dtype) * inp + (~keep).to(inp.dtype) * baseline

    if isinstance(feature_mask, Tensor):
        return _blend(original_inputs, feature_mask, baselines)

    # Index inputs/baselines by position (rather than zip) so a length
    # mismatch still fails loudly instead of silently truncating.
    return tuple(
        _blend(original_inputs[j], mask, baselines[j])
        for j, mask in enumerate(feature_mask)
    )