# neuron_explainer/explanations/attention_head_scoring.py
"""Uses API calls to score attention head explanations."""
from __future__ import annotations
import random
from typing import Any
import numpy as np
from sklearn.metrics import roc_auc_score
from neuron_explainer.activations.activations import (
ActivationRecord,
ActivationRecordSliceParams,
load_neuron,
)
from neuron_explainer.activations.attention_utils import (
convert_flattened_index_to_unflattened_index,
)
from neuron_explainer.api_client import ApiClient
from neuron_explainer.explanations.explainer import (
ATTENTION_EXPLANATION_PREFIX,
ContextSize,
format_attention_head_token_pair_string,
)
from neuron_explainer.explanations.explanations import (
AttentionSimulation,
ScoredAttentionSimulation,
)
from neuron_explainer.explanations.few_shot_examples import ATTENTION_HEAD_FEW_SHOT_EXAMPLES
from neuron_explainer.explanations.prompt_builder import (
ChatMessage,
PromptBuilder,
PromptFormat,
Role,
)
class AttentionHeadOneAtATimeScorer:
def __init__(
self,
model_name: str,
prompt_format: PromptFormat = PromptFormat.CHAT_MESSAGES,
# This parameter lets us adjust the length of the prompt when we're scoring explanations
# using older models with shorter context windows. In the future we can use it to experiment
# with longer context windows.
context_size: ContextSize = ContextSize.FOUR_K,
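# If True, each prompt also restates the marked token pair explicitly as (to_token, from_token).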
repeat_strongly_attending_pairs: bool = False,
max_concurrent: int | None = 10,
cache: bool = False,
):
assert (
prompt_format == PromptFormat.CHAT_MESSAGES
), f"Unhandled prompt format {prompt_format}"
self.prompt_format = prompt_format
self.context_size = context_size
self.client = ApiClient(model_name=model_name, max_concurrent=max_concurrent, cache=cache)
self.repeat_strongly_attending_pairs = repeat_strongly_attending_pairs
async def score_explanation(
self,
explanation: str,
activation_records: list[ActivationRecord],
max_activation: float,
# The number of high and low activating token pairs to sample for simulation
num_activations_for_scoring: int = 5,
# The fraction of max_activation below which a token pair is eligible for sampling
# as a low activating pair.
low_activation_threshold: float = 0.1,
) -> ScoredAttentionSimulation:
"""Score explanations based on how well they predict attention between
top attending token pairs and random low attending token pairs."""
# Use the activation records to generate a set of token pairs for scoring. With the default
# arguments this yields up to 10 pairs: the five top-activating pairs, plus up to five randomly
# chosen pairs whose activations fall below low_activation_threshold * the max activation.
candidates = []
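# Each candidate is a (record index, activation value, unflattened token pair coordinates) triple.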
for i, record in enumerate(activation_records):
sorted_activation_flat_indices = np.argsort(record.activations)[::-1]
sorted_vals = [record.activations[idx] for idx in sorted_activation_flat_indices]
coordinates = [
convert_flattened_index_to_unflattened_index(flat_index)
for flat_index in sorted_activation_flat_indices
]
candidates.extend([(i, val, coords) for val, coords in zip(sorted_vals, coordinates)])
top_activation_coordinates = [
(candidate[0], candidate[2])
for candidate in sorted(candidates, key=lambda x: x[1], reverse=True)
][:num_activations_for_scoring]
filtered_low_activation_coordinates = [
(candidate[0], candidate[2])
for candidate in candidates
if candidate[1] < low_activation_threshold * max_activation
]
selected_low_activation_coordinates = random.sample(
filtered_low_activation_coordinates,
min(len(filtered_low_activation_coordinates), num_activations_for_scoring),
)
attention_simulations = []
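# Label 1 marks a top-activating pair, label 0 a low-activating pair.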
true_labels = [1 for _ in range(len(top_activation_coordinates))] + [
0 for _ in range(len(selected_low_activation_coordinates))
]
# No need to shuffle because the model only sees one at a time anyway.
for coords, label in zip(
top_activation_coordinates + selected_low_activation_coordinates, true_labels
):
activation_record = activation_records[coords[0]]
# For each pair, build a prompt asking the model to predict whether the token pair has a
# strong or weak activation.
prompt = self.make_token_pair_prompt(explanation, activation_record.tokens, coords[1])
assert isinstance(prompt, list)
assert isinstance(prompt[0], dict) # Really a ChatMessage
generate_kwargs: dict[str, Any] = {
# Using a timeout prevents the scorer from hanging if the API server is overloaded.
"timeout": 60,
"n": 1,
"max_tokens": 1, # we only want to sample one token.
"logprobs": True,
"top_logprobs": 15,
"messages": prompt,
}
response = await self.client.async_generate(**generate_kwargs)
assert len(response["choices"]) == 1
# From the response, extract the log probabilities for "0" (weak) and "1" (strong) and
# convert them into a single float prediction.
choice = response["choices"][0]
# The top logprobs for the single sampled token come back as a list of {token, logprob} dicts.
logprobs_dicts = choice["logprobs"]["content"][0]["top_logprobs"]
extracted_probs = {d["token"]: d["logprob"] for d in logprobs_dicts}
zero_prob = np.exp(extracted_probs["0"]) if "0" in extracted_probs else 0.0
one_prob = np.exp(extracted_probs["1"]) if "1" in extracted_probs else 0.0
total_prob = zero_prob + one_prob
# The score is 0 * normalized probability of "0" + 1 * normalized probability of "1", which
# reduces to just the normalized probability of "1".
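# Note: this assumes at least one of "0" or "1" appeared in the top logprobs; otherwise
# total_prob is zero and the division below fails.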
normalized_one_prob = one_prob / total_prob
# print(f"zero_prob: {zero_prob/total_prob}, one_prob: {normalized_one_prob}")
attention_simulations.append(
AttentionSimulation(
tokens=activation_record.tokens,
token_pair_coords=coords[1],
token_pair_label=label,
simulation_prediction=normalized_one_prob,
)
)
assert (
len(attention_simulations)
== len(true_labels)
== len(top_activation_coordinates) + len(selected_low_activation_coordinates)
)
# ROC AUC awards a perfect score to explanations that order all of the scores
# for pairs labeled "1" above the scores for pairs labeled "0" (even if the scores
# for pairs labeled "0" are well above 0).
score = roc_auc_score(
y_true=true_labels, y_score=[sim.simulation_prediction for sim in attention_simulations]
)
return ScoredAttentionSimulation(
attention_simulations=attention_simulations,
roc_auc_score=score,
)
def make_token_pair_prompt(
self, explanation: str, tokens: list[str], coords: tuple[int, int]
) -> str | list[ChatMessage]:
"""
Create a prompt to send to the API to simulate the model predicting whether a token pair
has a strong attention write norm according to the given explanation.
"""
prompt_builder = PromptBuilder()
prompt_builder.add_message(
Role.SYSTEM,
"We're studying attention heads in a neural network. Each head looks at every pair of tokens "
"in a short token sequence and activates for pairs of tokens that fit what it is looking for. "
"Attention heads always attend from a token to a token earlier in the sequence (or from a "
"token to itself). We will display a token sequence and indicate a particular token pair within "
'that sequence. The "to" token of the pair will be marked with double asterisks (e.g., **token**) '
'and the "from" token will be marked with double square brackets (e.g., [[token]]). If the token pair '
"consists of a token paired with itself, it will be marked with both (e.g., [[**token**]]) and "
"no other token in the sequence will be marked. We present an explanation of what the "
"attention head is looking for. Output 1 if the head activates for the token pair, and 0 otherwise.",
)
num_few_shot = 0
for few_shot_example in ATTENTION_HEAD_FEW_SHOT_EXAMPLES:
if not few_shot_example.simulation_examples:
continue
for simulation_example in few_shot_example.simulation_examples:
self._add_per_token_pair_attention_simulation_prompt(
prompt_builder=prompt_builder,
tokens=few_shot_example.token_pair_examples[
simulation_example.token_pair_example_index
].tokens,
explanation=few_shot_example.explanation,
simulation_coords=simulation_example.token_pair_coordinates,
index=num_few_shot,
label=simulation_example.label,
)
num_few_shot += 1
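# Finally, add the unlabeled example to score; label=None leaves the assistant turn open for
# the model to complete.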
self._add_per_token_pair_attention_simulation_prompt(
prompt_builder=prompt_builder,
tokens=tokens,
explanation=explanation,
simulation_coords=coords,
index=num_few_shot,
label=None,
)
return prompt_builder.build(self.prompt_format)
def _add_per_token_pair_attention_simulation_prompt(
self,
prompt_builder: PromptBuilder,
tokens: list[str],
explanation: str,
simulation_coords: tuple[int, int],
index: int,
label: int | None, # None means this is the end of the full prompt.
) -> None:
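"""Append one token pair example to the prompt: a user message showing the marked
sequence and, for few-shot examples, an assistant message with the 0/1 label."""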
user_message = f"""
Example {index + 1}
Explanation: {ATTENTION_EXPLANATION_PREFIX} {explanation.strip()}
Sequence:\n{format_attention_head_token_pair_string(tokens, simulation_coords)}"""
if self.repeat_strongly_attending_pairs:
user_message += (
f"\nThe same token pair, presented as (to_token, from_token): "
f"({tokens[simulation_coords[1]]}, {tokens[simulation_coords[0]]})"
)
user_message += (
f"\nPrediction of whether attention head {index + 1} activates on the token pair: "
)
prompt_builder.add_message(Role.USER, user_message)
if label is not None:
prompt_builder.add_message(Role.ASSISTANT, f"{label}")
if __name__ == "__main__":
# Example usage
async def main() -> None:
scorer = AttentionHeadOneAtATimeScorer("gpt-4o")
explanation = "attends from tokens to the first token in the sequence"
attention_head = load_neuron(
"https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/attn_write_norm/collated_activations_by_token_pair",
"0",
"5",
)
train_records = attention_head.train_activation_records(
activation_record_slice_params=ActivationRecordSliceParams(n_examples_per_split=5)
)
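# Score the explanation against the training records; the max activation sets the threshold
# for sampling low-activating pairs.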
scored_simulation = await scorer.score_explanation(
explanation, train_records, max([max(record.activations) for record in train_records])
)
print(scored_simulation.roc_auc_score)
import asyncio
asyncio.run(main())