def _self_influence_batch_tracincp()

in captum/influence/_core/tracincp.py [0:0]


    def _self_influence_batch_tracincp(self, batch: Tuple[Any, ...]):
        """
        Computes self influence scores for a single batch
        """

        def get_checkpoint_contribution(checkpoint):

            assert (
                checkpoint is not None
            ), "None returned from `checkpoints`, cannot load."

            learning_rate = self.checkpoints_load_func(self.model, checkpoint)

            layer_jacobians = self._basic_computation_tracincp(batch[0:-1], batch[-1])

            # note that all variables in this function are for an entire batch.
            # each `layer_jacobian` in `layer_jacobians` corresponds to a different
            # layer. `layer_jacobian` is the jacobian w.r.t to a given layer's
            # parameters. if the given layer's parameters are of shape *, then
            # `layer_jacobian` is of shape (batch_size, *). for each layer, we need
            # the squared jacobian for each example. so we square the jacobian and
            # sum over all dimensions except the 0-th (the batch dimension). We then
            # sum the contribution over all layers.
            return (
                torch.sum(
                    torch.stack(
                        [
                            torch.sum(layer_jacobian.flatten(start_dim=1) ** 2, dim=1)
                            for layer_jacobian in layer_jacobians
                        ],
                        dim=0,
                    ),
                    dim=0,
                )
                * learning_rate
            )

        batch_self_tracin_scores = get_checkpoint_contribution(self.checkpoints[0])

        for checkpoint in self.checkpoints[1:]:
            batch_self_tracin_scores += get_checkpoint_contribution(checkpoint)

        return batch_self_tracin_scores