in captum/_utils/sample_gradient.py [0:0]
def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
    """Backpropagate ``loss_blob`` and run the per-module sample-grad handlers.

    Steps, in order:

    1. ``loss_mode`` is upper-cased and checked against the member names
       of ``LossMode`` (so the check is case-insensitive).
    2. ``self.model`` has its gradients zeroed and ``loss_blob`` is
       backpropagated with an all-ones upstream gradient of the same shape.
    3. For every module recorded in ``self.gradient_dict``, the handler
       registered in ``SUPPORTED_MODULES`` for the module's exact type is
       called once per saved (activation, gradient) pair.
    4. ``self._reset()`` is called.

    Args:
        loss_blob: Tensor to backpropagate from; it is passed to
            ``torch.ones_like`` and its ``backward`` method is invoked.
        loss_mode: Name of a ``LossMode`` member, in any letter case.
            For every mode other than ``SUM`` each saved gradient is
            multiplied by ``act.shape[0]`` before being handed to the
            handler (presumably to undo a mean reduction over the batch
            dimension -- confirm against the ``LossMode`` definition).

    Raises:
        AssertionError: If ``loss_mode`` does not name a ``LossMode``
            member, or if a module has a different number of saved
            activations and saved gradients.
    """
    mode_key = loss_mode.upper()
    assert (
        mode_key in LossMode.__members__
    ), f"Provided loss mode {loss_mode} is not valid"
    # Resolve the reduction once; it does not change per module or per call.
    sum_reduction = LossMode[mode_key] is LossMode.SUM

    self.model.zero_grad()
    upstream = torch.ones_like(loss_blob)
    loss_blob.backward(gradient=upstream)

    for module, saved_grads in self.gradient_dict.items():
        handler = SUPPORTED_MODULES[type(module)]
        saved_acts = self.activation_dict[module]
        assert len(saved_acts) == len(saved_grads), (
            "Number of saved activations do not match number of saved gradients."
            " This may occur if multiple forward passes are run without calling"
            " reset or computing param gradients."
        )
        # A module that is used several times records its activations in
        # forward order, while its gradients arrive in backprop (reverse)
        # order, so the gradients are walked back-to-front to line them up.
        paired = zip(saved_acts, reversed(saved_grads))
        for call_idx, (act, grad) in enumerate(paired):
            if sum_reduction:
                scale = 1
            else:
                scale = act.shape[0]
            # Only the first pair for a module asks the handler to reset.
            handler(module, act, grad * scale, reset=(call_idx == 0))

    self._reset()