in tensorwatch/saliency/inverter_util.py [0:0]
def conv_nd_inverse(self, m, relevance_in):
# In case the output had been reshaped for a linear layer,
# make sure the relevance is put into the same shape as before.
relevance_in = relevance_in.view(m.out_shape)
# Get required values from layer
inv_conv_nd = self.get_inv_conv_method(m)
conv_nd = self.get_conv_method(m)
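# e-rule in formula form: with z_ij = x_i^p * w_ij^p and
# z_j = sum_i z_ij stabilized by eps * sign(z_j), the relevance flowing
# back to input i is R_i = sum_j (z_ij / z_j) * R_j. The forward conv
# computes the denominators z_j, the transposed conv distributes
# R_j / z_j back onto the inputs, and the final multiplication by
# x_i^p restores the z_ij in the numerator.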
if self.method == "e-rule":
with torch.no_grad():
m.in_tensor = m.in_tensor.pow(self.p).detach()
w = m.weight.pow(self.p).detach()
norm = conv_nd(m.in_tensor, weight=w, bias=None,
stride=m.stride, padding=m.padding,
groups=m.groups)
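# Stabilize the denominator by pushing it away from zero; where it is
# exactly zero the inputs themselves are zero, so no relevance is
# propagated through those positions.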
norm = norm + torch.sign(norm) * self.eps
relevance_in[norm == 0] = 0
norm[norm == 0] = 1
relevance_out = inv_conv_nd(relevance_in/norm,
weight=w, bias=None,
padding=m.padding, stride=m.stride,
groups=m.groups)
relevance_out *= m.in_tensor
del m.in_tensor, norm, w
return relevance_out
if self.method == "b-rule":
with torch.no_grad():
w = m.weight
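# b-rule (alpha/beta-rule with alpha = 1 + beta) in formula form:
# R_i = sum_j [ (1 + beta) * z_ij^+ / z_j^+ - beta * z_ij^- / z_j^- ] * R_j,
# where z^+ / z^- collect the positive / negative contributions
# x_i * w_ij to output j. The four-fold repetition of weights and
# inputs below builds exactly these positive and negative parts.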
out_c, in_c = m.out_channels, m.in_channels
repeats = np.ones(w.dim(), dtype=int)
repeats[0] *= 4
w = w.repeat(tuple(repeats))
# First and third channel repetitions contain only the positive weights.
w[:out_c][w[:out_c] < 0] = 0
w[2 * out_c:3 * out_c][w[2 * out_c:3 * out_c] < 0] = 0
# Second and fourth channel repetitions contain only the negative weights.
w[1 * out_c:2 * out_c][w[1 * out_c:2 * out_c] > 0] = 0
w[-out_c:][w[-out_c:] > 0] = 0
repeats = np.ones(m.in_tensor.dim(), dtype=int)
repeats[1] *= 4
# Repeat across channel dimension (pytorch always has channels first)
m.in_tensor = m.in_tensor.repeat(tuple(repeats))
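# Blocks 1 and 4 of the repeated input keep only the positive
# activations, blocks 2 and 3 only the negative ones. Combined with
# the weight blocks above (+, -, +, -) this yields, per output-channel
# group, the products w+x+, w-x-, w+x-, w-x+: the first two are the
# positive contributions, the last two the negative ones.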
m.in_tensor[:, :in_c][m.in_tensor[:, :in_c] < 0] = 0
m.in_tensor[:, -in_c:][m.in_tensor[:, -in_c:] < 0] = 0
m.in_tensor[:, 1 * in_c:3 * in_c][m.in_tensor[:, 1 * in_c:3 * in_c] > 0] = 0
groups = 4
# Compute the normalization term: for each output neuron j, the conv
# below simply sums the contributions of all inputs feeding into j
# (v_ij in the paper). Dividing the individual contributions by this
# sum makes them add up to 1 per output neuron, so that the relevance
# of j can be split proportionally among its inputs.
norm = conv_nd(m.in_tensor, weight=w, bias=None, stride=m.stride,
padding=m.padding, dilation=m.dilation, groups=groups * m.groups)
# Double the number of output channels to hold the positive and the
# negative norm per channel. Working on a copy of out_shape (a plain
# list) keeps this ND-generic and avoids mutating the stored shape.
new_shape = list(m.out_shape)
new_shape[1] *= 2
new_norm = torch.zeros(new_shape).to(self.device)
new_norm[:, :out_c] = norm[:, :out_c] + norm[:, out_c:2 * out_c]
new_norm[:, out_c:] = norm[:, 2 * out_c:3 * out_c] + norm[:, 3 * out_c:]
norm = new_norm
# Some 'rare' neurons receive either only positive or only negative
# inputs. Conservation of relevance would break if we also rescaled
# those neurons by (1 + beta) or -beta. Therefore, catch them first
# and scale the norm by the corresponding factor, so that it cancels
# in the fraction. First, however, avoid NaNs.
mask = norm == 0
# Set the norm to anything non-zero, e.g. 1.
# The actual inputs are zero at this point anyways, that
# is why norm is zero in the first place.
norm[mask] = 1
# The norm in the b-rule has shape (N, 2*out_c, *spatial_dims).
# The first out_c block corresponds to the positive norms,
# the second out_c block corresponds to the negative norms.
# We find the rare neurons by choosing those nodes per channel
# in which either the positive norm ([:, :out_c]) is zero or
# the negative norm ([:, out_c:]) is zero.
rare_neurons = (mask[:, :out_c] | mask[:, out_c:])
# Also catch the new possibilities for norm == 0 that the scaling
# below would introduce when beta is -1 or 0; the actual value of
# norm does not matter then, since the pre-factor is zero anyway.
norm[:, :out_c][rare_neurons] *= 1 if self.beta == -1 else 1 + self.beta
norm[:, out_c:][rare_neurons] *= 1 if self.beta == 0 else -self.beta
# Add stabilizer term to norm to avoid numerical instabilities.
norm += self.eps * torch.sign(norm)
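# Apply the beta-rule pre-factors per block: the two positive blocks
# are scaled by (1 + beta) / norm^+, the two negative blocks by
# -beta / norm^-.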
spatial_dims = [1] * len(relevance_in.size()[2:])
input_relevance = relevance_in.repeat(1, 4, *spatial_dims)
input_relevance[:, :2*out_c] *= (1+self.beta)/norm[:, :out_c].repeat(1, 2, *spatial_dims)
input_relevance[:, 2*out_c:] *= -self.beta/norm[:, out_c:].repeat(1, 2, *spatial_dims)
# Each of the positive / negative entries needs its own
# convolution. TODO: Can this be done in groups, too?
relevance_out = torch.zeros_like(m.in_tensor)
# The transposed convolution can come out smaller than the original
# input when the stride does not evenly divide the spatial size, so
# copy its result into a zero tensor of the full input size.
tmp_result = result = None
for i in range(4):
tmp_result = inv_conv_nd(
input_relevance[:, i*out_c:(i+1)*out_c],
weight=w[i*out_c:(i+1)*out_c],
bias=None, padding=m.padding, stride=m.stride,
groups=m.groups)
result = torch.zeros_like(relevance_out[:, i*in_c:(i+1)*in_c])
tmp_size = tmp_result.size()
slice_list = tuple(slice(0, l) for l in tmp_size)
result[slice_list] += tmp_result
relevance_out[:, i*in_c:(i+1)*in_c] = result
relevance_out *= m.in_tensor
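# Collapse the four blocks back to the original number of input
# channels: sum_weights is a 0/1 kernel of spatial size 1, so the conv
# below simply adds up, for every input channel i, the relevance of
# its four copies (channels i, i + in_c, i + 2*in_c, i + 3*in_c).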
sum_weights = torch.zeros([in_c, in_c * 4, *spatial_dims]).to(self.device)
for i in range(m.in_channels):
sum_weights[i, i::in_c] = 1
relevance_out = conv_nd(relevance_out, weight=sum_weights, bias=None)
del sum_weights, m.in_tensor, result, mask, rare_neurons, norm, \
new_norm, input_relevance, tmp_result, w
return relevance_out
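
# Illustrative sketch, not part of the original module: a minimal,
# self-contained check of the epsilon-style redistribution used in the
# "e-rule" branch above, written directly against torch.nn.functional with
# p = 1 and no stride, padding or groups. The helper name is made up for
# this example; it only demonstrates that the relevance pushed back onto
# the inputs sums to the relevance at the output (conservation), which is
# the property the normalization is built to provide.
def _sketch_eps_rule_conservation(eps=1e-6):
    import torch
    import torch.nn.functional as F
    torch.manual_seed(0)
    x = torch.rand(1, 3, 8, 8)        # non-negative inputs (e.g. post-ReLU)
    w = torch.rand(5, 3, 3, 3)        # non-negative weights for simplicity
    r_out = torch.rand(1, 5, 6, 6)    # relevance arriving at the layer output
    # norm_j = sum_i x_i * w_ij, plus the eps stabilizer (norm is positive here).
    norm = F.conv2d(x, w) + eps
    # Distribute r_out / norm back through the same weights, then scale by x.
    r_in = F.conv_transpose2d(r_out / norm, w) * x
    # With strictly positive x and w the two totals agree up to eps.
    return r_out.sum().item(), r_in.sum().item()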