in captum/attr/_core/layer/layer_lrp.py [0:0]
def _get_single_output_relevance(self, layer, output):
    """Scale the relevance stored on ``layer.rule`` by ``output``.

    Steps:
    - Select ``layer.rule.relevance_input`` when
      ``self.attribute_to_layer_input`` is truthy, otherwise
      ``layer.rule.relevance_output``.
    - Order that dictionary's keys with ``_sort_key_list`` (guided by
      ``self.device_ids``).
    - Merge the corresponding values with ``_reduce_list``.
    - Multiply every merged relevance by ``output``. ``output`` is first
      reshaped to ``(-1,)`` followed by enough singleton dimensions to match
      the relevance's rank, so the product broadcasts along the leading
      dimension only.

    Args:
        layer: module whose ``rule`` holds ``relevance_input`` /
            ``relevance_output`` dictionaries (keys are treated as device
            ids here).
        output: tensor-like value that supports ``reshape``. Presumably it
            holds one value per example along the leading dimension; TODO
            confirm with the caller.

    Returns:
        A single scaled tensor. If ``_reduce_list`` produced a tuple, a
        tuple of scaled tensors in the same order instead.
    """
    rule = layer.rule
    per_device = (
        rule.relevance_input
        if self.attribute_to_layer_input
        else rule.relevance_output
    )
    ordered_ids = _sort_key_list(list(per_device.keys()), self.device_ids)
    merged = _reduce_list([per_device[dev] for dev in ordered_ids])

    def _scale(relevance):
        # Give ``output`` the same rank as ``relevance`` with all trailing
        # dims set to 1, so each leading-dim entry scales its whole slice.
        broadcast_shape = (-1,) + (1,) * (relevance.dim() - 1)
        return relevance * output.reshape(broadcast_shape)

    if not isinstance(merged, tuple):
        return _scale(merged)
    return tuple(_scale(relevance) for relevance in merged)