neuron-explainer/neuron_explainer/explanations/explainer.py (348 lines of code) (raw):

"""Uses API calls to generate explanations of neuron behavior.""" from __future__ import annotations import logging import re from abc import ABC, abstractmethod from enum import Enum from typing import Any, Optional, Sequence, Union from neuron_explainer.activations.activation_records import ( calculate_max_activation, format_activation_records, non_zero_activation_proportion, ) from neuron_explainer.activations.activations import ActivationRecord from neuron_explainer.api_client import ApiClient from neuron_explainer.explanations.few_shot_examples import FewShotExampleSet from neuron_explainer.explanations.prompt_builder import ( HarmonyMessage, PromptBuilder, PromptFormat, Role, ) from neuron_explainer.explanations.token_space_few_shot_examples import ( TokenSpaceFewShotExampleSet, ) logger = logging.getLogger(__name__) # TODO(williamrs): This prefix may not work well for some things, like predicting the next token. # Try other options like "this neuron activates for". EXPLANATION_PREFIX = "the main thing this neuron does is find" def _split_numbered_list(text: str) -> list[str]: """Split a numbered list into a list of strings.""" lines = re.split(r"\n\d+\.", text) # Strip the leading whitespace from each line. return [line.lstrip() for line in lines] def _remove_final_period(text: str) -> str: """Strip a final period or period-space from a string.""" if text.endswith("."): return text[:-1] elif text.endswith(". "): return text[:-2] return text class ContextSize(int, Enum): TWO_K = 2049 FOUR_K = 4097 @classmethod def from_int(cls, i: int) -> ContextSize: for context_size in cls: if context_size.value == i: return context_size raise ValueError(f"{i} is not a valid ContextSize") HARMONY_V4_MODELS = ["gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"] class NeuronExplainer(ABC): """ Abstract base class for Explainer classes that generate explanations from subclass-specific input data. """ def __init__( self, model_name: str, prompt_format: PromptFormat = PromptFormat.HARMONY_V4, # This parameter lets us adjust the length of the prompt when we're generating explanations # using older models with shorter context windows. In the future we can use it to experiment # with longer context windows. context_size: ContextSize = ContextSize.FOUR_K, max_concurrent: Optional[int] = 10, cache: bool = False, ): if prompt_format == PromptFormat.HARMONY_V4: assert model_name in HARMONY_V4_MODELS elif prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]: assert model_name not in HARMONY_V4_MODELS else: raise ValueError(f"Unhandled prompt format {prompt_format}") self.model_name = model_name self.prompt_format = prompt_format self.context_size = context_size self.client = ApiClient(model_name=model_name, max_concurrent=max_concurrent, cache=cache) async def generate_explanations( self, *, num_samples: int = 5, max_tokens: int = 60, temperature: float = 1.0, top_p: float = 1.0, **prompt_kwargs: Any, ) -> list[Any]: """Generate explanations based on subclass-specific input data.""" prompt = self.make_explanation_prompt(max_tokens_for_completion=max_tokens, **prompt_kwargs) generate_kwargs: dict[str, Any] = { "n": num_samples, "max_tokens": max_tokens, "temperature": temperature, "top_p": top_p, } if self.prompt_format == PromptFormat.HARMONY_V4: assert isinstance(prompt, list) assert isinstance(prompt[0], dict) # Really a HarmonyMessage generate_kwargs["messages"] = prompt else: assert isinstance(prompt, str) generate_kwargs["prompt"] = prompt response = await self.client.make_request(**generate_kwargs) logger.debug("response in generate_explanations is %s", response) if self.prompt_format == PromptFormat.HARMONY_V4: explanations = [x["message"]["content"] for x in response["choices"]] elif self.prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]: explanations = [x["text"] for x in response["choices"]] else: raise ValueError(f"Unhandled prompt format {self.prompt_format}") return self.postprocess_explanations(explanations, prompt_kwargs) @abstractmethod def make_explanation_prompt(self, **kwargs: Any) -> Union[str, list[HarmonyMessage]]: """ Create a prompt to send to the API to generate one or more explanations. A prompt can be a simple string, or a list of HarmonyMessages, depending on the PromptFormat used by this instance. """ ... def postprocess_explanations( self, completions: list[str], prompt_kwargs: dict[str, Any] ) -> list[Any]: """Postprocess the completions returned by the API into a list of explanations.""" return completions # no-op by default def _prompt_is_too_long( self, prompt_builder: PromptBuilder, max_tokens_for_completion: int ) -> bool: # We'll get a context size error if the prompt itself plus the maximum number of tokens for # the completion is longer than the context size. prompt_length = prompt_builder.prompt_length_in_tokens(self.prompt_format) if prompt_length + max_tokens_for_completion > self.context_size.value: print( f"Prompt is too long: {prompt_length} + {max_tokens_for_completion} > " f"{self.context_size.value}" ) return True return False class TokenActivationPairExplainer(NeuronExplainer): """ Generate explanations of neuron behavior using a prompt with lists of token/activation pairs. """ def __init__( self, model_name: str, prompt_format: PromptFormat = PromptFormat.HARMONY_V4, # This parameter lets us adjust the length of the prompt when we're generating explanations # using older models with shorter context windows. In the future we can use it to experiment # with 8k+ context windows. context_size: ContextSize = ContextSize.FOUR_K, few_shot_example_set: FewShotExampleSet = FewShotExampleSet.ORIGINAL, repeat_non_zero_activations: bool = True, max_concurrent: Optional[int] = 10, cache: bool = False, ): super().__init__( model_name=model_name, prompt_format=prompt_format, max_concurrent=max_concurrent, cache=cache, ) self.context_size = context_size self.few_shot_example_set = few_shot_example_set self.repeat_non_zero_activations = repeat_non_zero_activations def make_explanation_prompt(self, **kwargs: Any) -> Union[str, list[HarmonyMessage]]: original_kwargs = kwargs.copy() all_activation_records: Sequence[ActivationRecord] = kwargs.pop("all_activation_records") max_activation: float = kwargs.pop("max_activation") kwargs.setdefault("numbered_list_of_n_explanations", None) numbered_list_of_n_explanations: Optional[int] = kwargs.pop( "numbered_list_of_n_explanations" ) if numbered_list_of_n_explanations is not None: assert numbered_list_of_n_explanations > 0, numbered_list_of_n_explanations # This parameter lets us dynamically shrink the prompt if our initial attempt to create it # results in something that's too long. It's only implemented for the 4k context size. kwargs.setdefault("omit_n_activation_records", 0) omit_n_activation_records: int = kwargs.pop("omit_n_activation_records") max_tokens_for_completion: int = kwargs.pop("max_tokens_for_completion") assert not kwargs, f"Unexpected kwargs: {kwargs}" prompt_builder = PromptBuilder() prompt_builder.add_message( Role.SYSTEM, "We're studying neurons in a neural network. Each neuron looks for some particular " "thing in a short document. Look at the parts of the document the neuron activates for " "and summarize in a single sentence what the neuron is looking for. Don't list " "examples of words.\n\nThe activation format is token<tab>activation. Activation " "values range from 0 to 10. A neuron finding what it's looking for is represented by a " "non-zero activation value. The higher the activation value, the stronger the match.", ) few_shot_examples = self.few_shot_example_set.get_examples() num_omitted_activation_records = 0 for i, few_shot_example in enumerate(few_shot_examples): few_shot_activation_records = few_shot_example.activation_records if self.context_size == ContextSize.TWO_K: # If we're using a 2k context window, we only have room for one activation record # per few-shot example. (Two few-shot examples with one activation record each seems # to work better than one few-shot example with two activation records, in local # testing.) few_shot_activation_records = few_shot_activation_records[:1] elif ( self.context_size == ContextSize.FOUR_K and num_omitted_activation_records < omit_n_activation_records ): # Drop the last activation record for this few-shot example to save tokens, assuming # there are at least two activation records. if len(few_shot_activation_records) > 1: print(f"Warning: omitting activation record from few-shot example {i}") few_shot_activation_records = few_shot_activation_records[:-1] num_omitted_activation_records += 1 self._add_per_neuron_explanation_prompt( prompt_builder, few_shot_activation_records, i, calculate_max_activation(few_shot_example.activation_records), numbered_list_of_n_explanations=numbered_list_of_n_explanations, explanation=few_shot_example.explanation, ) self._add_per_neuron_explanation_prompt( prompt_builder, # If we're using a 2k context window, we only have room for two of the activation # records. all_activation_records[:2] if self.context_size == ContextSize.TWO_K else all_activation_records, len(few_shot_examples), max_activation, numbered_list_of_n_explanations=numbered_list_of_n_explanations, explanation=None, ) # If the prompt is too long *and* we omitted the specified number of activation records, try # again, omitting one more. (If we didn't make the specified number of omissions, we're out # of opportunities to omit records, so we just return the prompt as-is.) if ( self._prompt_is_too_long(prompt_builder, max_tokens_for_completion) and num_omitted_activation_records == omit_n_activation_records ): original_kwargs["omit_n_activation_records"] = omit_n_activation_records + 1 return self.make_explanation_prompt(**original_kwargs) return prompt_builder.build(self.prompt_format) def _add_per_neuron_explanation_prompt( self, prompt_builder: PromptBuilder, activation_records: Sequence[ActivationRecord], index: int, max_activation: float, # When set, this indicates that the prompt should solicit a numbered list of the given # number of explanations, rather than a single explanation. numbered_list_of_n_explanations: Optional[int], explanation: Optional[str], # None means this is the end of the full prompt. ) -> None: max_activation = calculate_max_activation(activation_records) user_message = f""" Neuron {index + 1} Activations:{format_activation_records(activation_records, max_activation, omit_zeros=False)}""" # We repeat the non-zero activations only if it was requested and if the proportion of # non-zero activations isn't too high. if ( self.repeat_non_zero_activations and non_zero_activation_proportion(activation_records, max_activation) < 0.2 ): user_message += ( f"\nSame activations, but with all zeros filtered out:" f"{format_activation_records(activation_records, max_activation, omit_zeros=True)}" ) if numbered_list_of_n_explanations is None: user_message += f"\nExplanation of neuron {index + 1} behavior:" assistant_message = "" # For the IF format, we want <|endofprompt|> to come before the explanation prefix. if self.prompt_format == PromptFormat.INSTRUCTION_FOLLOWING: assistant_message += f" {EXPLANATION_PREFIX}" else: user_message += f" {EXPLANATION_PREFIX}" prompt_builder.add_message(Role.USER, user_message) if explanation is not None: assistant_message += f" {explanation}." if assistant_message: prompt_builder.add_message(Role.ASSISTANT, assistant_message) else: if explanation is None: # For the final neuron, we solicit a numbered list of explanations. prompt_builder.add_message( Role.USER, f"""\nHere are {numbered_list_of_n_explanations} possible explanations for neuron {index + 1} behavior, each beginning with "{EXPLANATION_PREFIX}":\n1. {EXPLANATION_PREFIX}""", ) else: # For the few-shot examples, we only present one explanation, but we present it as a # numbered list. prompt_builder.add_message( Role.USER, f"""\nHere is 1 possible explanation for neuron {index + 1} behavior, beginning with "{EXPLANATION_PREFIX}":\n1. {EXPLANATION_PREFIX}""", ) prompt_builder.add_message(Role.ASSISTANT, f" {explanation}.") def postprocess_explanations( self, completions: list[str], prompt_kwargs: dict[str, Any] ) -> list[Any]: """Postprocess the explanations returned by the API""" numbered_list_of_n_explanations = prompt_kwargs.get("numbered_list_of_n_explanations") if numbered_list_of_n_explanations is None: return completions else: all_explanations = [] for completion in completions: for explanation in _split_numbered_list(completion): if explanation.startswith(EXPLANATION_PREFIX): explanation = explanation[len(EXPLANATION_PREFIX) :] all_explanations.append(explanation.strip()) return all_explanations class TokenSpaceRepresentationExplainer(NeuronExplainer): """ Generate explanations of arbitrary lists of tokens which disproportionately activate a particular neuron. These lists of tokens can be generated in various ways. As an example, in one set of experiments, we compute the average activation for each neuron conditional on each token that appears in an internet text corpus. We then sort the tokens by their average activation, and show 50 of the top 100 tokens. Other techniques that could be used include taking the top tokens in the logit lens or tuned lens representations of a neuron. """ def __init__( self, model_name: str, prompt_format: PromptFormat = PromptFormat.HARMONY_V4, context_size: ContextSize = ContextSize.FOUR_K, few_shot_example_set: TokenSpaceFewShotExampleSet = TokenSpaceFewShotExampleSet.ORIGINAL, use_few_shot: bool = False, output_numbered_list: bool = False, max_concurrent: Optional[int] = 10, cache: bool = False, ): super().__init__( model_name=model_name, prompt_format=prompt_format, context_size=context_size, max_concurrent=max_concurrent, cache=cache, ) self.use_few_shot = use_few_shot self.output_numbered_list = output_numbered_list if self.use_few_shot: assert few_shot_example_set is not None self.few_shot_examples: Optional[TokenSpaceFewShotExampleSet] = few_shot_example_set else: self.few_shot_examples = None self.prompt_prefix = ( "We're studying neurons in a neural network. Each neuron looks for some particular " "kind of token (which can be a word, or part of a word). Look at the tokens the neuron " "activates for (listed below) and summarize in a single sentence what the neuron is " "looking for. Don't list examples of words." ) def make_explanation_prompt(self, **kwargs: Any) -> Union[str, list[HarmonyMessage]]: tokens: list[str] = kwargs.pop("tokens") max_tokens_for_completion = kwargs.pop("max_tokens_for_completion") assert not kwargs, f"Unexpected kwargs: {kwargs}" # Note that this does not preserve the precise tokens, as e.g. # f" {token_with_no_leading_space}" may be tokenized as "f{token_with_leading_space}". # TODO(dan): Try out other variants, including "\n".join(...) and ",".join(...) stringified_tokens = ", ".join([f"'{t}'" for t in tokens]) prompt_builder = PromptBuilder() prompt_builder.add_message(Role.SYSTEM, self.prompt_prefix) if self.use_few_shot: self._add_few_shot_examples(prompt_builder) self._add_neuron_specific_prompt(prompt_builder, stringified_tokens, explanation=None) if self._prompt_is_too_long(prompt_builder, max_tokens_for_completion): raise ValueError(f"Prompt too long: {prompt_builder.build(self.prompt_format)}") else: return prompt_builder.build(self.prompt_format) def _add_few_shot_examples(self, prompt_builder: PromptBuilder) -> None: """ Append few-shot examples to the prompt. Each one consists of a comma-delimited list of tokens and corresponding explanations, as saved in alignment/neuron_explainer/weight_explainer/token_space_few_shot_examples.py. """ assert self.few_shot_examples is not None few_shot_example_list = self.few_shot_examples.get_examples() if self.output_numbered_list: raise NotImplementedError("Numbered list output not supported for few-shot examples") else: for few_shot_example in few_shot_example_list: self._add_neuron_specific_prompt( prompt_builder, ", ".join([f"'{t}'" for t in few_shot_example.tokens]), explanation=few_shot_example.explanation, ) def _add_neuron_specific_prompt( self, prompt_builder: PromptBuilder, stringified_tokens: str, explanation: Optional[str], ) -> None: """ Append a neuron-specific prompt to the prompt builder. The prompt consists of a list of tokens followed by either an explanation (if one is passed, for few shot examples) or by the beginning of a completion, to be completed by the model with an explanation. """ user_message = f"\n\n\n\nTokens:\n{stringified_tokens}\n\nExplanation:\n" assistant_message = "" looking_for = "This neuron is looking for" if self.prompt_format == PromptFormat.INSTRUCTION_FOLLOWING: # We want <|endofprompt|> to come before "This neuron is looking for" in the IF format. assistant_message += looking_for else: user_message += looking_for if self.output_numbered_list: start_of_list = "\n1." if self.prompt_format == PromptFormat.INSTRUCTION_FOLLOWING: assistant_message += start_of_list else: user_message += start_of_list if explanation is not None: assistant_message += f"{explanation}." prompt_builder.add_message(Role.USER, user_message) if assistant_message: prompt_builder.add_message(Role.ASSISTANT, assistant_message) def postprocess_explanations( self, completions: list[str], prompt_kwargs: dict[str, Any] ) -> list[str]: if self.output_numbered_list: # Each list in the top-level list will have multiple explanations (multiple strings). all_explanations = [] for completion in completions: for explanation in _split_numbered_list(completion): if explanation.startswith(EXPLANATION_PREFIX): explanation = explanation[len(EXPLANATION_PREFIX) :] all_explanations.append(explanation.strip()) return all_explanations else: # Each element in the top-level list will be an explanation as a string. return [_remove_final_period(explanation) for explanation in completions]