in neuron_explainer/explanations/explainer.py [0:0]
def make_explanation_prompt(self, **kwargs: Any) -> str | list[ChatMessage]:
    original_kwargs = kwargs.copy()
    all_activation_records: list[ActivationRecord] = kwargs.pop("all_activations")
    # This parameter lets us dynamically shrink the prompt if our initial attempt to create it
    # results in something that's too long.
    kwargs.setdefault("omit_n_token_pair_examples", 0)
    omit_n_token_pair_examples: int = kwargs.pop("omit_n_token_pair_examples")
    max_tokens_for_completion: int = kwargs.pop("max_tokens_for_completion")
    kwargs.setdefault("num_top_pairs_to_display", 0)
    num_top_pairs_to_display: int = kwargs.pop("num_top_pairs_to_display")
    assert not kwargs, f"Unexpected kwargs: {kwargs}"
    prompt_builder = PromptBuilder()
    prompt_builder.add_message(
        Role.SYSTEM,
        "We're studying attention heads in a neural network. Each head looks at every pair of tokens "
        "in a short token sequence and activates for pairs of tokens that fit what it is looking for. "
        "Attention heads always attend from a token to a token earlier in the sequence (or from a "
        'token to itself). We will display multiple instances of sequences with the "to" token '
        'surrounded by double asterisks (e.g., **token**) and the "from" token surrounded by double '
        "square brackets (e.g., [[token]]). If a token attends from itself to itself, it will be "
        "surrounded by both (e.g., [[**token**]]). Look at the pairs of tokens the head activates for "
        "and summarize in a single sentence what pattern the head is looking for. We do not display "
        "every activating pair of tokens in a sequence; you must generalize from limited examples. "
        "Remember, the head always attends to tokens earlier in the sequence (marked with ** **) from "
        "tokens later in the sequence (marked with [[ ]]), except when the head attends from a token to "
        'itself (marked with [[** **]]). The explanation takes the form: "This attention head attends '
        "to {pattern of tokens marked with ** **, which appear earlier} from {pattern of tokens marked with "
        '[[ ]], which appear later}." The explanation does not include any of the markers (** **, [[ ]]), '
        f"as these are just for your reference. Sequences are separated by `{ATTENTION_SEQUENCE_SEPARATOR}`.",
    )
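    # For reference, an illustrative example of the marker convention described above:
    # in the displayed sequence "The **plan** failed, [[as]] expected", the head attends
    # from the later token "as" to the earlier token "plan"; a token attending to itself
    # would be shown as [[**token**]].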
    num_omitted_token_pair_examples = 0
    for i, few_shot_example in enumerate(ATTENTION_HEAD_FEW_SHOT_EXAMPLES):
        few_shot_token_pair_examples = few_shot_example.token_pair_examples
        if num_omitted_token_pair_examples < omit_n_token_pair_examples:
            # Drop the last token pair example for this few-shot example to save tokens,
            # assuming there are at least two token pair examples.
            if len(few_shot_token_pair_examples) > 1:
                print(f"Warning: omitting token pair example from few-shot example {i}")
                few_shot_token_pair_examples = few_shot_token_pair_examples[:-1]
                num_omitted_token_pair_examples += 1
        few_shot_explanation: str = few_shot_example.explanation
        self._add_per_head_explanation_prompt(
            prompt_builder,
            few_shot_token_pair_examples,
            i,
            explanation=few_shot_explanation,
        )
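    # The few-shot examples above carry their known explanations; the example for the
    # head under study is appended below with explanation=None so that the model
    # completes the explanation itself.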
    # Each element of coords is (record_index, attention_value, (from_token_index, to_token_index)).
    coords = get_top_attention_coordinates(
        all_activation_records, top_k=num_top_pairs_to_display
    )
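    # Illustrative shape (hypothetical values): coords might look like
    #   [(3, 0.92, (7, 2)), (3, 0.81, (9, 9)), (0, 0.64, (4, 1))]
    # i.e. the top-k attention values across all records, each paired with the
    # coordinates of its (from, to) token pair.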
    # Group the selected coordinates by activation record so that each record
    # contributes a single AttentionTokenPairExample containing all of its top pairs.
    prompt_examples: dict[int, AttentionTokenPairExample] = {}
    for record_index, _, (from_token_index, to_token_index) in coords:
        if record_index not in prompt_examples:
            prompt_examples[record_index] = AttentionTokenPairExample(
                tokens=all_activation_records[record_index].tokens,
                token_pair_coordinates=[(from_token_index, to_token_index)],
            )
        else:
            prompt_examples[record_index].token_pair_coordinates.append(
                (from_token_index, to_token_index)
            )
    current_head_token_pair_examples = list(prompt_examples.values())
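    # Continuing the hypothetical coords above, prompt_examples would map
    #   3 -> AttentionTokenPairExample(tokens=..., token_pair_coordinates=[(7, 2), (9, 9)])
    #   0 -> AttentionTokenPairExample(tokens=..., token_pair_coordinates=[(4, 1)])
    # and current_head_token_pair_examples would contain those two examples.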
    self._add_per_head_explanation_prompt(
        prompt_builder,
        current_head_token_pair_examples,
        len(ATTENTION_HEAD_FEW_SHOT_EXAMPLES),
        explanation=None,
    )
    # If the prompt is too long *and* we omitted the specified number of token pair examples,
    # try again, omitting one more. (If we didn't manage the specified number of omissions,
    # we're out of opportunities to omit examples, so we just return the prompt as-is.)
    if (
        self._prompt_is_too_long(prompt_builder, max_tokens_for_completion)
        and num_omitted_token_pair_examples == omit_n_token_pair_examples
    ):
        original_kwargs["omit_n_token_pair_examples"] = omit_n_token_pair_examples + 1
        return self.make_explanation_prompt(**original_kwargs)
    return prompt_builder.build(self.prompt_format)
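
# A hypothetical call, for illustration only (`explainer` and `records` are placeholder
# names; the keyword arguments are the keys popped above):
#
#     prompt = explainer.make_explanation_prompt(
#         all_activations=records,
#         max_tokens_for_completion=60,
#         num_top_pairs_to_display=20,
#     )
#
# If the resulting prompt is too long, the method re-invokes itself with
# omit_n_token_pair_examples incremented by one, so callers normally leave that
# parameter at its default of 0.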