def _self_influence_batch_tracincp_fast()

in captum/influence/_core/tracincp_fast_rand_proj.py [0:0]


    def _self_influence_batch_tracincp_fast(self, batch: Tuple[Any, ...]):
        """
        Computes self influence scores for every example in a single batch.

        Each checkpoint in ``self.checkpoints`` is processed in order:

        1. It is loaded into ``self.model`` via ``self.checkpoints_load_func``.
           The return value of that call is used as the checkpoint's learning
           rate.
        2. The per-example score
           ``sum(jacobian_i ** 2) * sum(layer_input_i ** 2) * learning_rate``
           is computed.

        The scores of all checkpoints are added together.

        Args:
            batch (tuple): All elements except the last are passed as the
                second argument to ``_basic_computation_tracincp_fast``. The
                last element is passed as the third argument. Presumably these
                are the model inputs and the targets/labels; confirm against
                the helper.

        Returns:
            A tensor of per-example self influence scores summed over all
            checkpoints. The tensor produced for the first checkpoint is
            accumulated into in place.

        Raises:
            AssertionError: if any checkpoint in ``self.checkpoints`` is None.
            IndexError: if ``self.checkpoints`` is empty.
        """

        def checkpoint_score(checkpoint):
            # Load `checkpoint` into the model, then return that checkpoint's
            # per-example contribution to the self influence scores.
            assert (
                checkpoint is not None
            ), "None returned from `checkpoints`, cannot load."

            learning_rate = self.checkpoints_load_func(self.model, checkpoint)

            jacobian, layer_input = _basic_computation_tracincp_fast(
                self, batch[0:-1], batch[-1]
            )

            # Squared norms are reduced over dim 1. Presumably dim 0 indexes
            # the examples in the batch (TODO confirm). Since
            # ||a b^T||_F^2 == ||a||^2 * ||b||^2, the product below gives the
            # squared norm of an outer-product gradient without materializing
            # that gradient. This assumes the relevant gradient has that form.
            jacobian_sq_norm = jacobian.pow(2).sum(dim=1)
            layer_input_sq_norm = layer_input.pow(2).sum(dim=1)
            return jacobian_sq_norm * layer_input_sq_norm * learning_rate

        # The first checkpoint seeds the accumulator. Indexing [0] is what
        # raises IndexError when there are no checkpoints.
        scores = checkpoint_score(self.checkpoints[0])

        for remaining_checkpoint in self.checkpoints[1:]:
            # Accumulate in place into the first checkpoint's tensor.
            scores += checkpoint_score(remaining_checkpoint)

        return scores