def loss_fn_for_backward_pass()

in neuron_explainer/activation_server/derived_scalar_computation.py [0:0]


    def loss_fn_for_backward_pass(output_logits: torch.Tensor) -> torch.Tensor:
        """Scalar loss used to seed the backward pass.

        target_tokens_as_ints, distractor_tokens_as_ints, and subtract_mean are
        captured from the enclosing scope. The loss is the mean logit of the
        target tokens at the final token position, optionally contrasted against
        the mean over all logits (subtract_mean) or against the mean logit of
        the distractor tokens.
        """
        assert output_logits.ndim == 3
        nbatch, ntoken, nlogit = output_logits.shape
        assert nbatch == 1
        assert len(target_tokens_as_ints) > 0
        # Mean logit of the target tokens at the last token position.
        target_mean = output_logits[:, -1, target_tokens_as_ints].mean(-1)
        if len(distractor_tokens_as_ints) == 0:
            loss = target_mean.mean()  # average logits for target tokens
            if subtract_mean:
                loss -= output_logits[:, -1, :].mean()
            return loss
        else:
            assert (
                not subtract_mean
            ), "subtract_mean is not a meaningful option when distractor_tokens is specified"
            distractor_mean = output_logits[:, -1, distractor_tokens_as_ints].mean(-1)
            return (
                target_mean - distractor_mean
            ).mean()  # difference between average logits for target and distractor tokens
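A minimal usage sketch, assuming loss_fn_for_backward_pass has already been
built by its enclosing function with target_tokens_as_ints,
distractor_tokens_as_ints, and subtract_mean bound; the logits tensor below is
random stand-in data rather than the output of a real model.

    import torch

    # Stand-in logits with the expected shape (nbatch=1, ntoken, nlogit);
    # requires_grad lets gradients flow back into the tensor directly.
    output_logits = torch.randn(1, 8, 50257, requires_grad=True)

    loss = loss_fn_for_backward_pass(output_logits)
    loss.backward()

    # Only the final token position enters the loss, so only that slice of
    # the gradient is nonzero.
    assert output_logits.grad is not None
    assert output_logits.grad[:, :-1, :].abs().sum() == 0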