# neuron_explainer/explanations/explainer.py
"""Uses API calls to generate explanations of neuron behavior."""

from __future__ import annotations

import logging
import re
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Sequence

import numpy as np

from neuron_explainer.activations.activation_records import (
calculate_max_activation,
format_activation_records,
non_zero_activation_proportion,
)
from neuron_explainer.activations.activations import ActivationRecord
from neuron_explainer.activations.attention_utils import (
convert_flattened_index_to_unflattened_index,
)
from neuron_explainer.api_client import ApiClient
from neuron_explainer.explanations.few_shot_examples import (
ATTENTION_HEAD_FEW_SHOT_EXAMPLES,
AttentionTokenPairExample,
FewShotExampleSet,
)
from neuron_explainer.explanations.prompt_builder import (
ChatMessage,
PromptBuilder,
PromptFormat,
Role,
)

logger = logging.getLogger(__name__)

EXPLANATION_PREFIX = "this neuron activates for"
ATTENTION_EXPLANATION_PREFIX = "this attention head"
ATTENTION_SEQUENCE_SEPARATOR = "<|sequence_separator|>"


def _split_numbered_list(text: str) -> list[str]:
"""Split a numbered list into a list of strings."""
lines = re.split(r"\n\d+\.", text)
# Strip the leading whitespace from each line.
return [line.lstrip() for line in lines]
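# For example, a completion that continues a seeded numbered list splits like this
# (hypothetical text):
#   _split_numbered_list("looks for colors\n2. looks for animals")
#   -> ["looks for colors", "looks for animals"]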


class ContextSize(int, Enum):
TWO_K = 2049
FOUR_K = 4097

    @classmethod
def from_int(cls, i: int) -> ContextSize:
for context_size in cls:
if context_size.value == i:
return context_size
raise ValueError(f"{i} is not a valid ContextSize")


class NeuronExplainer(ABC):
"""
Abstract base class for Explainer classes that generate explanations from subclass-specific
input data.
"""

    def __init__(
self,
model_name: str,
prompt_format: PromptFormat = PromptFormat.CHAT_MESSAGES,
# This parameter lets us adjust the length of the prompt when we're generating explanations
# using older models with shorter context windows. In the future we can use it to experiment
# with longer context windows.
context_size: ContextSize = ContextSize.FOUR_K,
max_concurrent: int | None = 10,
cache: bool = False,
):
self.prompt_format = prompt_format
self.context_size = context_size
self.client = ApiClient(model_name=model_name, max_concurrent=max_concurrent, cache=cache)

    async def generate_explanations(
self,
*,
num_samples: int = 1,
max_tokens: int = 60,
temperature: float = 1.0,
top_p: float = 1.0,
**prompt_kwargs: Any,
) -> list[Any]:
"""Generate explanations based on subclass-specific input data."""
prompt = self.make_explanation_prompt(max_tokens_for_completion=max_tokens, **prompt_kwargs)
generate_kwargs: dict[str, Any] = {
# Using a timeout prevents the explainer from hanging if the API server is overloaded.
"timeout": 60,
"n": num_samples,
"max_tokens": max_tokens,
"temperature": temperature,
"top_p": top_p,
}
if self.prompt_format == PromptFormat.CHAT_MESSAGES:
assert isinstance(prompt, list)
assert isinstance(prompt[0], dict) # Really a ChatMessage
generate_kwargs["messages"] = prompt
else:
assert isinstance(prompt, str)
generate_kwargs["prompt"] = prompt
response = await self.client.async_generate(**generate_kwargs)
logger.debug("response in generate_explanations is %s", response)
if self.prompt_format == PromptFormat.CHAT_MESSAGES:
explanations = [x["message"]["content"] for x in response["choices"]]
elif self.prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]:
explanations = [x["text"] for x in response["choices"]]
else:
raise ValueError(f"Unhandled prompt format {self.prompt_format}")
return self.postprocess_explanations(explanations, prompt_kwargs)

    @abstractmethod
def make_explanation_prompt(self, **kwargs: Any) -> str | list[ChatMessage]:
"""
Create a prompt to send to the API to generate one or more explanations.
A prompt can be a simple string, or a list of ChatMessages, depending on the PromptFormat
used by this instance.
"""
...

    def postprocess_explanations(
self, completions: list[str], prompt_kwargs: dict[str, Any]
) -> list[Any]:
"""Postprocess the completions returned by the API into a list of explanations."""
return completions # no-op by default

    def _prompt_is_too_long(
self, prompt_builder: PromptBuilder, max_tokens_for_completion: int
) -> bool:
# We'll get a context size error if the prompt itself plus the maximum number of tokens for
# the completion is longer than the context size.
prompt_length = prompt_builder.prompt_length_in_tokens(self.prompt_format)
if prompt_length + max_tokens_for_completion > self.context_size.value:
print(
f"Prompt is too long: {prompt_length} + {max_tokens_for_completion} > "
f"{self.context_size.value}"
)
return True
return False
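# A subclass only needs to implement make_explanation_prompt (and may override
# postprocess_explanations); a minimal sketch (hypothetical subclass):
#
#   class ConstantPromptExplainer(NeuronExplainer):
#       def make_explanation_prompt(self, **kwargs: Any) -> str | list[ChatMessage]:
#           builder = PromptBuilder()
#           builder.add_message(Role.SYSTEM, "Explain what the neuron does.")
#           builder.add_message(Role.USER, "Neuron 1\nActivations: ...")
#           return builder.build(self.prompt_format)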


class TokenActivationPairExplainer(NeuronExplainer):
"""
Generate explanations of neuron behavior using a prompt with lists of token/activation pairs.
"""

    def __init__(
self,
model_name: str,
prompt_format: PromptFormat = PromptFormat.CHAT_MESSAGES,
# This parameter lets us adjust the length of the prompt when we're generating explanations
# using older models with shorter context windows. In the future we can use it to experiment
# with 8k+ context windows.
context_size: ContextSize = ContextSize.FOUR_K,
few_shot_example_set: FewShotExampleSet = FewShotExampleSet.ORIGINAL,
repeat_non_zero_activations: bool = False,
max_concurrent: int | None = 10,
cache: bool = False,
):
super().__init__(
model_name=model_name,
prompt_format=prompt_format,
max_concurrent=max_concurrent,
cache=cache,
)
self.context_size = context_size
self.few_shot_example_set = few_shot_example_set
self.repeat_non_zero_activations = repeat_non_zero_activations

    def make_explanation_prompt(self, **kwargs: Any) -> str | list[ChatMessage]:
original_kwargs = kwargs.copy()
all_activation_records: Sequence[ActivationRecord] = kwargs.pop("all_activations")
max_activation: float = kwargs.pop("max_activation")
kwargs.setdefault("numbered_list_of_n_explanations", None)
numbered_list_of_n_explanations: int | None = kwargs.pop("numbered_list_of_n_explanations")
if numbered_list_of_n_explanations is not None:
assert numbered_list_of_n_explanations > 0, numbered_list_of_n_explanations
# This parameter lets us dynamically shrink the prompt if our initial attempt to create it
# results in something that's too long. It's only implemented for the 4k context size.
kwargs.setdefault("omit_n_activation_records", 0)
omit_n_activation_records: int = kwargs.pop("omit_n_activation_records")
max_tokens_for_completion: int = kwargs.pop("max_tokens_for_completion")
assert not kwargs, f"Unexpected kwargs: {kwargs}"
prompt_builder = PromptBuilder()
prompt_builder.add_message(
Role.SYSTEM,
"We're studying neurons in a neural network. Each neuron looks for some particular "
"thing in a short document. Look at the parts of the document the neuron activates for "
"and summarize in a single sentence what the neuron is looking for. Don't list "
"examples of words.\n\nThe activation format is token<tab>activation. Activation "
"values range from 0 to 10. A neuron finding what it's looking for is represented by a "
"non-zero activation value. The higher the activation value, the stronger the match.",
)
few_shot_examples = self.few_shot_example_set.get_examples()
num_omitted_activation_records = 0
for i, few_shot_example in enumerate(few_shot_examples):
few_shot_activation_records = few_shot_example.activation_records
if self.context_size == ContextSize.TWO_K:
# If we're using a 2k context window, we only have room for one activation record
# per few-shot example. (Two few-shot examples with one activation record each seems
# to work better than one few-shot example with two activation records, in local
# testing.)
few_shot_activation_records = few_shot_activation_records[:1]
elif (
self.context_size == ContextSize.FOUR_K
and num_omitted_activation_records < omit_n_activation_records
):
# Drop the last activation record for this few-shot example to save tokens, assuming
# there are at least two activation records.
if len(few_shot_activation_records) > 1:
print(f"Warning: omitting activation record from few-shot example {i}")
few_shot_activation_records = few_shot_activation_records[:-1]
num_omitted_activation_records += 1
self._add_per_neuron_explanation_prompt(
prompt_builder,
few_shot_activation_records,
i,
calculate_max_activation(few_shot_example.activation_records),
numbered_list_of_n_explanations=numbered_list_of_n_explanations,
explanation=few_shot_example.explanation,
)
self._add_per_neuron_explanation_prompt(
prompt_builder,
# If we're using a 2k context window, we only have room for two of the activation
# records.
(
all_activation_records[:2]
if self.context_size == ContextSize.TWO_K
else all_activation_records
),
len(few_shot_examples),
max_activation,
numbered_list_of_n_explanations=numbered_list_of_n_explanations,
explanation=None,
)
# If the prompt is too long *and* we omitted the specified number of activation records, try
# again, omitting one more. (If we didn't make the specified number of omissions, we're out
# of opportunities to omit records, so we just return the prompt as-is.)
if (
self._prompt_is_too_long(prompt_builder, max_tokens_for_completion)
and num_omitted_activation_records == omit_n_activation_records
):
original_kwargs["omit_n_activation_records"] = omit_n_activation_records + 1
return self.make_explanation_prompt(**original_kwargs)
return prompt_builder.build(self.prompt_format)

    def _add_per_neuron_explanation_prompt(
self,
prompt_builder: PromptBuilder,
activation_records: Sequence[ActivationRecord],
index: int,
max_activation: float,
# When set, this indicates that the prompt should solicit a numbered list of the given
# number of explanations, rather than a single explanation.
numbered_list_of_n_explanations: int | None,
explanation: str | None, # None means this is the end of the full prompt.
) -> None:
user_message = f"""
Neuron {index + 1}
Activations:{format_activation_records(activation_records, max_activation, omit_zeros=False)}"""
# We repeat the non-zero activations only if it was requested and if the proportion of
# non-zero activations isn't too high.
if (
self.repeat_non_zero_activations
and non_zero_activation_proportion(activation_records, max_activation) < 0.2
):
user_message += (
f"\nSame activations, but with all zeros filtered out:"
f"{format_activation_records(activation_records, max_activation, omit_zeros=True)}"
)
if numbered_list_of_n_explanations is None:
user_message += f"\nExplanation of neuron {index + 1} behavior:"
assistant_message = ""
# For the IF format, we want <|endofprompt|> to come before the explanation prefix.
if self.prompt_format == PromptFormat.INSTRUCTION_FOLLOWING:
assistant_message += f" {EXPLANATION_PREFIX}"
else:
user_message += f" {EXPLANATION_PREFIX}"
prompt_builder.add_message(Role.USER, user_message)
if explanation is not None:
assistant_message += f" {explanation}."
if assistant_message:
prompt_builder.add_message(Role.ASSISTANT, assistant_message)
else:
if explanation is None:
# For the final neuron, we solicit a numbered list of explanations.
prompt_builder.add_message(
Role.USER,
f"""\nHere are {numbered_list_of_n_explanations} possible explanations for neuron {index + 1} behavior, each beginning with "{EXPLANATION_PREFIX}":\n1. {EXPLANATION_PREFIX}""",
)
else:
# For the few-shot examples, we only present one explanation, but we present it as a
# numbered list.
prompt_builder.add_message(
Role.USER,
f"""\nHere is 1 possible explanation for neuron {index + 1} behavior, beginning with "{EXPLANATION_PREFIX}":\n1. {EXPLANATION_PREFIX}""",
)
prompt_builder.add_message(Role.ASSISTANT, f" {explanation}.")

    def postprocess_explanations(
        self, completions: list[str], prompt_kwargs: dict[str, Any]
    ) -> list[Any]:
        """Postprocess the explanations returned by the API."""
numbered_list_of_n_explanations = prompt_kwargs.get("numbered_list_of_n_explanations")
if numbered_list_of_n_explanations is None:
return completions
else:
all_explanations = []
for completion in completions:
for explanation in _split_numbered_list(completion):
if explanation.startswith(EXPLANATION_PREFIX):
explanation = explanation[len(EXPLANATION_PREFIX) :]
all_explanations.append(explanation.strip())
return all_explanations
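# A minimal usage sketch (hypothetical model name and activation records; await
# inside an event loop, e.g. via asyncio.run):
#
#   explainer = TokenActivationPairExplainer(model_name="gpt-4")
#   explanations = await explainer.generate_explanations(
#       all_activations=records,  # Sequence[ActivationRecord]
#       max_activation=calculate_max_activation(records),
#       num_samples=1,
#       max_tokens=60,
#   )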


def format_attention_head_token_pairs(
    token_pair_examples: list[AttentionTokenPairExample], omit_zeros: bool = False
) -> str:
    """Format token pair examples for inclusion in the prompt.

    With omit_zeros=True, render only the highlighted pairs as (to_token, from_token)
    tuples; otherwise render each full sequence with one pair marked inline, separating
    sequences with ATTENTION_SEQUENCE_SEPARATOR.
    """
if omit_zeros:
return ", ".join(
[
", ".join(
[
f"({example.tokens[coords[1]]}, {example.tokens[coords[0]]})"
for coords in example.token_pair_coordinates
]
)
for example in token_pair_examples
]
)
else:
return f"\n{ATTENTION_SEQUENCE_SEPARATOR}\n".join(
[
f"\n{ATTENTION_SEQUENCE_SEPARATOR}\n".join(
[
f"{format_attention_head_token_pair_string(example.tokens, coords)}"
for coords in example.token_pair_coordinates
]
)
for example in token_pair_examples
]
)


def format_attention_head_token_pair_string(
    token_list: list[str], pair_coordinates: tuple[int, int]
) -> str:
    """Render a token sequence with one attention pair marked inline.

    The "to" token (pair_coordinates[1]) is wrapped in double asterisks, the "from"
    token (pair_coordinates[0]) in double square brackets, and a self-attending token
    in both.
    """
def format_activated_token(i: int, token: str) -> str:
if i == pair_coordinates[0] and i == pair_coordinates[1]:
return f"[[**{token}**]]" # from and to
if i == pair_coordinates[0]:
return f"[[{token}]]" # from
if i == pair_coordinates[1]:
return f"**{token}**" # to
return token
return "".join([format_activated_token(i, token) for i, token in enumerate(token_list)])


def get_top_attention_coordinates(
    activation_records: list[ActivationRecord], top_k: int = 5
) -> list[tuple[int, float, tuple[int, int]]]:
    """Find the most strongly activating token pairs across all activation records.

    Returns up to top_k (record_index, activation_value, (from_index, to_index))
    triples, sorted by activation value in descending order.
    """
    candidates = []
for i, record in enumerate(activation_records):
top_activation_flat_indices = np.argsort(record.activations)[::-1][:top_k]
top_vals: list[float] = [record.activations[idx] for idx in top_activation_flat_indices]
top_coordinates = [
convert_flattened_index_to_unflattened_index(flat_index)
for flat_index in top_activation_flat_indices
]
candidates.extend(
[(i, top_val, coords) for top_val, coords in zip(top_vals, top_coordinates)]
)
return sorted(candidates, key=lambda x: x[1], reverse=True)[:top_k]
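# Shape sketch (hypothetical values): for two records this might return
#   [(1, 0.93, (7, 2)), (0, 0.88, (5, 5))]
# where record 1's token 7 attends back to token 2, and record 0's token 5 attends
# to itself.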


class AttentionHeadExplainer(NeuronExplainer):
    """
    Generate explanations of attention head behavior using a prompt with lists of
    strongly attending to/from token pairs.

    Takes in NeuronRecords corresponding to a single attention head and extracts
    strongly activating to/from token pairs.
    """

    def __init__(
self,
model_name: str,
prompt_format: PromptFormat = PromptFormat.CHAT_MESSAGES,
# This parameter lets us adjust the length of the prompt when we're generating explanations
# using older models with shorter context windows. In the future we can use it to experiment
# with 8k+ context windows.
context_size: ContextSize = ContextSize.FOUR_K,
repeat_strongly_attending_pairs: bool = False,
max_concurrent: int | None = 10,
cache: bool = False,
):
super().__init__(
model_name=model_name,
prompt_format=prompt_format,
max_concurrent=max_concurrent,
cache=cache,
)
assert (
context_size != ContextSize.TWO_K
), "2k context size not supported for attention explanation"
self.context_size = context_size
self.repeat_strongly_attending_pairs = repeat_strongly_attending_pairs

    def make_explanation_prompt(self, **kwargs: Any) -> str | list[ChatMessage]:
original_kwargs = kwargs.copy()
all_activation_records: list[ActivationRecord] = kwargs.pop("all_activations")
# This parameter lets us dynamically shrink the prompt if our initial attempt to create it
# results in something that's too long.
kwargs.setdefault("omit_n_token_pair_examples", 0)
omit_n_token_pair_examples: int = kwargs.pop("omit_n_token_pair_examples")
max_tokens_for_completion: int = kwargs.pop("max_tokens_for_completion")
kwargs.setdefault("num_top_pairs_to_display", 0)
num_top_pairs_to_display: int = kwargs.pop("num_top_pairs_to_display")
assert not kwargs, f"Unexpected kwargs: {kwargs}"
prompt_builder = PromptBuilder()
prompt_builder.add_message(
Role.SYSTEM,
"We're studying attention heads in a neural network. Each head looks at every pair of tokens "
"in a short token sequence and activates for pairs of tokens that fit what it is looking for. "
"Attention heads always attend from a token to a token earlier in the sequence (or from a "
'token to itself). We will display multiple instances of sequences with the "to" token '
'surrounded by double asterisks (e.g., **token**) and the "from" token surrounded by double '
"square brackets (e.g., [[token]]). If a token attends from itself to itself, it will be "
"surrounded by both (e.g., [[**token**]]). Look at the pairs of tokens the head activates for "
"and summarize in a single sentence what pattern the head is looking for. We do not display "
"every activating pair of tokens in a sequence; you must generalize from limited examples. "
"Remember, the head always attends to tokens earlier in the sentence (marked with ** **) from "
"tokens later in the sentence (marked with [[ ]]), except when the head attends from a token to "
'itself (marked with [[** **]]). The explanation takes the form: "This attention head attends '
"to {pattern of tokens marked with ** **, which appear earlier} from {pattern of tokens marked with "
'[[ ]], which appear later}." The explanation does not include any of the markers (** **, [[ ]]), '
f"as these are just for your reference. Sequences are separated by `{ATTENTION_SEQUENCE_SEPARATOR}`.",
)
num_omitted_token_pair_examples = 0
for i, few_shot_example in enumerate(ATTENTION_HEAD_FEW_SHOT_EXAMPLES):
few_shot_token_pair_examples = few_shot_example.token_pair_examples
if num_omitted_token_pair_examples < omit_n_token_pair_examples:
                # Drop the last token pair example for this few-shot example to save
                # tokens, assuming there are at least two examples.
                if len(few_shot_token_pair_examples) > 1:
                    print(f"Warning: omitting token pair example from few-shot example {i}")
few_shot_token_pair_examples = few_shot_token_pair_examples[:-1]
num_omitted_token_pair_examples += 1
few_shot_explanation: str = few_shot_example.explanation
self._add_per_head_explanation_prompt(
prompt_builder,
few_shot_token_pair_examples,
i,
explanation=few_shot_explanation,
)
        # Each element is (record_index, activation_value, (from_token_index, to_token_index)).
coords = get_top_attention_coordinates(
all_activation_records, top_k=num_top_pairs_to_display
)
prompt_examples = {}
for record_index, _, (from_token_index, to_token_index) in coords:
if record_index not in prompt_examples:
prompt_examples[record_index] = AttentionTokenPairExample(
tokens=all_activation_records[record_index].tokens,
token_pair_coordinates=[(from_token_index, to_token_index)],
)
else:
prompt_examples[record_index].token_pair_coordinates.append(
(from_token_index, to_token_index)
)
current_head_token_pair_examples = list(prompt_examples.values())
self._add_per_head_explanation_prompt(
prompt_builder,
current_head_token_pair_examples,
len(ATTENTION_HEAD_FEW_SHOT_EXAMPLES),
explanation=None,
)
        # If the prompt is too long *and* we omitted the specified number of token pair
        # examples, try again, omitting one more. (If we didn't make the specified number
        # of omissions, we're out of opportunities to omit examples, so we just return the
        # prompt as-is.)
if (
self._prompt_is_too_long(prompt_builder, max_tokens_for_completion)
and num_omitted_token_pair_examples == omit_n_token_pair_examples
):
original_kwargs["omit_n_token_pair_examples"] = omit_n_token_pair_examples + 1
return self.make_explanation_prompt(**original_kwargs)
return prompt_builder.build(self.prompt_format)

    def _add_per_head_explanation_prompt(
self,
prompt_builder: PromptBuilder,
        # Each example has `tokens` and `token_pair_coordinates` attributes.
        token_pair_examples: list[AttentionTokenPairExample],
index: int,
explanation: str | None, # None means this is the end of the full prompt.
) -> None:
user_message = f"""
Attention head {index + 1}
Activations:\n{format_attention_head_token_pairs(token_pair_examples, omit_zeros=False)}"""
if self.repeat_strongly_attending_pairs:
user_message += (
f"\nThe same list of strongly activating token pairs, presented as (to_token, from_token):"
f"{format_attention_head_token_pairs(token_pair_examples, omit_zeros=True)}"
)
user_message += f"\nExplanation of attention head {index + 1} behavior:"
assistant_message = ""
# For the IF format, we want <|endofprompt|> to come before the explanation prefix.
if self.prompt_format == PromptFormat.INSTRUCTION_FOLLOWING:
assistant_message += f" {ATTENTION_EXPLANATION_PREFIX}"
else:
user_message += f" {ATTENTION_EXPLANATION_PREFIX}"
prompt_builder.add_message(Role.USER, user_message)
if explanation is not None:
assistant_message += f" {explanation}."
if assistant_message:
prompt_builder.add_message(Role.ASSISTANT, assistant_message)
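# A minimal usage sketch (hypothetical model name and activation records; await
# inside an event loop, e.g. via asyncio.run):
#
#   explainer = AttentionHeadExplainer(model_name="gpt-4")
#   explanations = await explainer.generate_explanations(
#       all_activations=head_records,  # list[ActivationRecord] of flattened attention
#       num_top_pairs_to_display=10,
#       max_tokens=60,
#   )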