def compute_param_sample_gradients()

in captum/_utils/sample_gradient.py


    def compute_param_sample_gradients(self, loss_blob, loss_mode="mean"):
        assert (
            loss_mode.upper() in LossMode.__members__
        ), f"Provided loss mode {loss_mode} is not valid"
        mode = LossMode[loss_mode.upper()]

        self.model.zero_grad()
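        # The backward pass triggers the hooks registered by add_hooks, which
        # record each supported module's output gradients in self.gradient_dict.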
        loss_blob.backward(gradient=torch.ones_like(loss_blob))

        for module in self.gradient_dict:
            sample_grad_fn = SUPPORTED_MODULES[type(module)]
            activations = self.activation_dict[module]
            gradients = self.gradient_dict[module]
            assert len(activations) == len(gradients), (
                "Number of saved activations do not match number of saved gradients."
                " This may occur if multiple forward passes are run without calling"
                " reset or computing param gradients."
            )
            # Reverse the gradients: when a module is used multiple times in a
            # forward pass, backprop visits those uses in reverse order, so the
            # saved gradients align with the saved activations back-to-front.
            for i, (act, grad) in enumerate(
                zip(activations, list(reversed(gradients)))
            ):
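                # A mean-reduced loss scales each gradient by 1 / batch size,
                # so multiply by act.shape[0] to recover per-sample gradients.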
                mult = 1 if mode is LossMode.SUM else act.shape[0]
                sample_grad_fn(module, act, grad * mult, reset=(i == 0))
        self._reset()
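
A minimal usage sketch for context. It assumes the surrounding SampleGradientWrapper API in this file (the constructor, add_hooks, and remove_hooks) and that the per-module handlers in SUPPORTED_MODULES store results on each parameter's sample_grad attribute; the toy model, tensors, and printed shape below are illustrative only.

    import torch
    import torch.nn as nn

    from captum._utils.sample_gradient import SampleGradientWrapper

    # Toy model with two supported (nn.Linear) modules.
    model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 2))
    wrapper = SampleGradientWrapper(model)
    wrapper.add_hooks()  # register the forward/backward hooks that fill the dicts

    inputs = torch.randn(8, 10)
    targets = torch.randint(0, 2, (8,))

    # loss_mode must match the reduction used to compute the loss.
    loss = nn.CrossEntropyLoss(reduction="sum")(model(inputs), targets)
    wrapper.compute_param_sample_gradients(loss, loss_mode="sum")

    # Per-sample gradients are stored on the parameters of supported modules,
    # with a leading batch dimension.
    print(model[0].weight.sample_grad.shape)  # expected: torch.Size([8, 5, 10])

    wrapper.remove_hooks()

With loss_mode="mean" (the default), the function rescales each gradient by the batch size to undo the 1/N factor that a mean-reduced loss introduces, so the stored sample gradients match what a sum-reduced loss would produce.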