in tensorwatch/saliency/epsilon_lrp.py [0:0]
def innvestigate(self, in_tensor=None, rel_for_class=None):
    """
    Method for 'innvestigating' the model with the LRP rule chosen at
    the initialization of the InnvestigateModel.
    Args:
        in_tensor: Input for which to evaluate the LRP algorithm.
                    If input is None, the last evaluation is used.
                    If no evaluation has been performed since initialization,
                    an error is raised.
        rel_for_class (int): Index of the class for which the relevance
                    distribution is to be analyzed. If None, the 'winning' class
                    is used for indexing.
    Returns:
        Model output and relevances of nodes in the input layer.
        In order to get relevance distributions in other layers, use
        the get_r_values_per_layer method.
    Raises:
        RuntimeError: If no input is supplied and the model has never
                    been evaluated.
    """
    # Drop relevances from any previous run so their tensors can be
    # garbage-collected. (The original looped `del elt` over the list,
    # which only unbinds the loop variable and frees nothing.)
    self.r_values_per_layer = None
    with torch.no_grad():
        # Check if innvestigation can be performed.
        if in_tensor is None and self.prediction is None:
            raise RuntimeError("Model needs to be evaluated at least "
                               "once before an innvestigation can be "
                               "performed. Please evaluate model first "
                               "or call innvestigate with a new input to "
                               "evaluate.")
        # Evaluate the model anew if a new input is supplied.
        if in_tensor is not None:
            self.evaluate(in_tensor)
        org_shape = self.prediction.size()
        # Make sure shape is just a 1D vector per batch example.
        self.prediction = self.prediction.view(org_shape[0], -1)
        # zeros_like already matches device and dtype of the prediction;
        # an extra .to(self.device) would be redundant (and could create
        # a device mismatch on the masked assignment below).
        only_max_score = torch.zeros_like(self.prediction)
        if rel_for_class is None:
            # Default behaviour is innvestigating the output
            # on an arg-max-basis, if no class is specified.
            max_v, _ = torch.max(self.prediction, dim=1, keepdim=True)
            winners = max_v == self.prediction
            only_max_score[winners] = self.prediction[winners]
        else:
            only_max_score[:, rel_for_class] += self.prediction[:, rel_for_class]
        relevance_tensor = only_max_score.view(org_shape)
        # BUGFIX: the original called .view(org_shape) without assigning
        # the result (a no-op), leaving self.prediction flattened after
        # this method returned. Restore the original output shape.
        self.prediction = self.prediction.view(org_shape)
        # We have to iterate through the model backwards.
        # The module list is computed for every forward pass
        # by the model inverter.
        rev_model = self.inverter.module_list[::-1]
        relevance = relevance_tensor.detach()
        del relevance_tensor
        # List to save relevance distributions per layer
        # (output layer first, input layer last).
        r_values_per_layer = [relevance]
        for layer in rev_model:
            # Compute layer specific backwards-propagation of relevance values
            relevance = self.inverter.compute_propagated_relevance(layer, relevance)
            r_values_per_layer.append(relevance.cpu())
        self.r_values_per_layer = r_values_per_layer
        del relevance
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        return self.prediction, r_values_per_layer[-1]