in neuron_explainer/activation_server/derived_scalar_computation.py
def loss_fn_for_backward_pass(output_logits: torch.Tensor) -> torch.Tensor:
    # output_logits has shape (batch, tokens, vocab); only single-sequence batches
    # are supported. target_tokens_as_ints, distractor_tokens_as_ints, and
    # subtract_mean are closed over from the enclosing scope.
    assert output_logits.ndim == 3
    nbatch, ntoken, nlogit = output_logits.shape
    assert nbatch == 1
    assert len(target_tokens_as_ints) > 0
    # Average logit over the target tokens at the final token position.
    target_mean = output_logits[:, -1, target_tokens_as_ints].mean(-1)
    if len(distractor_tokens_as_ints) == 0:
        loss = target_mean.mean()  # average logits for target tokens
        if subtract_mean:
            # Center against the mean logit over the full vocabulary.
            loss -= output_logits[:, -1, :].mean()
        return loss
    else:
        assert (
            not subtract_mean
        ), "subtract_mean not a meaningful option when distractor_tokens is specified"
        distractor_mean = output_logits[:, -1, distractor_tokens_as_ints].mean(-1)
        return (
            target_mean - distractor_mean
        ).mean()  # difference between average logits for target and distractor tokens