in tensorwatch/saliency/inverter_util.py
def linear_inverse(self, m, relevance_in):
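    # e-rule (epsilon rule): relevance flowing into this linear layer is
    # redistributed to its inputs in proportion to their contributions
    # in_tensor^p * w^p, with the per-output normalizer stabilized by
    # eps * sign(norm) to avoid division by (near-)zero pre-activations.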
if self.method == "e-rule":
m.in_tensor = m.in_tensor.pow(self.p)
w = m.weight.pow(self.p)
norm = F.linear(m.in_tensor, w, bias=None)
norm = norm + torch.sign(norm) * self.eps
relevance_in[norm == 0] = 0
norm[norm == 0] = 1
relevance_out = F.linear(relevance_in / norm,
w.t(), bias=None)
relevance_out *= m.in_tensor
del m.in_tensor, norm, w, relevance_in
return relevance_out
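    # b-rule (beta rule): weights and inputs are split into positive and
    # negative parts, so the relevance of every output neuron can be divided
    # into a positive share scaled by (1 + beta) and a negative share scaled
    # by -beta before being redistributed to the inputs.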
if self.method == "b-rule":
out_c, in_c = m.weight.size()
w = m.weight.repeat((4, 1))
# First and third channel repetition only contain the positive weights
w[:out_c][w[:out_c] < 0] = 0
w[2 * out_c:3 * out_c][w[2 * out_c:3 * out_c] < 0] = 0
# Second and fourth channel repetition with only the negative weights
w[1 * out_c:2 * out_c][w[1 * out_c:2 * out_c] > 0] = 0
w[-out_c:][w[-out_c:] > 0] = 0
# Repeat across channel dimension (pytorch always has channels first)
m.in_tensor = m.in_tensor.repeat((1, 4))
m.in_tensor[:, :in_c][m.in_tensor[:, :in_c] < 0] = 0
m.in_tensor[:, -in_c:][m.in_tensor[:, -in_c:] < 0] = 0
m.in_tensor[:, 1 * in_c:3 * in_c][m.in_tensor[:, 1 * in_c:3 * in_c] > 0] = 0
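        # After the masking above, the four (input, weight) blocks pair up as
        # (x+, w+), (x-, w-), (x-, w+) and (x+, w-): the first two produce the
        # positive part of the pre-activation, the last two the negative part.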
        # Normalize such that the individual contributions of the input
        # neurons, divided by this norm, sum to 1 for every output neuron j
        # (v_ij in the paper). The norm simply sums, for each output j, the
        # contributions of all inputs feeding into j; dividing by it lets the
        # relevance of j be split up properly amongst its inputs.
        norm_shape = m.out_shape
        norm_shape[1] *= 4
        norm = torch.zeros(norm_shape).to(self.device)
        for i in range(4):
            norm[:, out_c * i:(i + 1) * out_c] = F.linear(
                m.in_tensor[:, in_c * i:(i + 1) * in_c],
                w[out_c * i:(i + 1) * out_c], bias=None)
        # Collapse the four norm blocks into two, i.e. twice the original
        # number of output channels: one positive and one negative norm
        # per channel.
        norm_shape[1] = norm_shape[1] // 2
        new_norm = torch.zeros(norm_shape).to(self.device)
        new_norm[:, :out_c] = norm[:, :out_c] + norm[:, out_c:2 * out_c]
        new_norm[:, out_c:] = norm[:, 2 * out_c:3 * out_c] + norm[:, 3 * out_c:]
        norm = new_norm
        # Some 'rare' neurons receive either only positive or only negative
        # inputs. For those, conservation of relevance does not hold if we
        # also rescale them by (1 + beta) or -beta. Therefore, catch them
        # first and scale the norm by the corresponding factor, so that it
        # cancels in the fraction below.
        # First, however, avoid NaNs.
        mask = norm == 0
        # Set the norm to anything non-zero, e.g. 1.
        # The actual inputs are zero at this point anyway, which
        # is why the norm is zero in the first place.
        norm[mask] = 1
        # The norm in the b-rule has shape (N, 2*out_c, *spatial_dims).
        # The first out_c block corresponds to the positive norms,
        # the second out_c block corresponds to the negative norms.
        # We find the rare neurons by choosing those nodes per channel
        # in which either the positive norm ([:, :out_c]) is zero or
        # the negative norm ([:, out_c:]) is zero.
        rare_neurons = (mask[:, :out_c] + mask[:, out_c:])
        # Also, catch any new possibilities for norm == 0 to avoid NaNs.
        # The actual value of the norm again does not really matter, since
        # the pre-factor will be zero in this case.
        norm[:, :out_c][rare_neurons] *= 1 if self.beta == -1 else 1 + self.beta
        norm[:, out_c:][rare_neurons] *= 1 if self.beta == 0 else -self.beta
        # Add a stabilizer term to the norm to avoid numerical instabilities.
        norm += self.eps * torch.sign(norm)
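        # Scale the incoming relevance per block: the two blocks built from
        # positive pre-activations share the positive norm and the factor
        # (1 + beta), the two built from negative pre-activations share the
        # negative norm and the factor -beta.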
        input_relevance = relevance_in.squeeze(dim=-1).repeat(1, 4)
        input_relevance[:, :2*out_c] *= (1+self.beta)/norm[:, :out_c].repeat(1, 2)
        input_relevance[:, 2*out_c:] *= -self.beta/norm[:, out_c:].repeat(1, 2)
        inv_w = w.t()
        relevance_out = torch.zeros_like(m.in_tensor)
        for i in range(4):
            relevance_out[:, i*in_c:(i+1)*in_c] = F.linear(
                input_relevance[:, i*out_c:(i+1)*out_c],
                weight=inv_w[:, i*out_c:(i+1)*out_c], bias=None)
        relevance_out *= m.in_tensor
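        # Collapse the four repeated blocks back to in_c input channels: the
        # 1x1 convolution below sums, for every input neuron i, the relevance
        # in channels i, i + in_c, i + 2*in_c and i + 3*in_c.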
        sum_weights = torch.zeros([in_c, in_c * 4, 1]).to(self.device)
        for i in range(in_c):
            sum_weights[i, i::in_c] = 1
        relevance_out = F.conv1d(relevance_out[:, :, None], weight=sum_weights, bias=None)
        del sum_weights, input_relevance, norm, rare_neurons, \
            mask, new_norm, m.in_tensor, w, inv_w
        return relevance_out