neuron-explainer/neuron_explainer/explanations/simulator.py (546 lines of code) (raw):

"""Uses API calls to simulate neuron activations based on an explanation."""

from __future__ import annotations

import asyncio
import logging
from abc import ABC, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Optional, Sequence, Union

import numpy as np

from neuron_explainer.activations.activation_records import (
    calculate_max_activation,
    format_activation_records,
    format_sequences_for_simulation,
    normalize_activations,
)
from neuron_explainer.activations.activations import ActivationRecord
from neuron_explainer.api_client import ApiClient
from neuron_explainer.explanations.explainer import EXPLANATION_PREFIX
from neuron_explainer.explanations.explanations import ActivationScale, SequenceSimulation
from neuron_explainer.explanations.few_shot_examples import FewShotExampleSet
from neuron_explainer.explanations.prompt_builder import (
    HarmonyMessage,
    PromptBuilder,
    PromptFormat,
    Role,
)

logger = logging.getLogger(__name__)

# Our prompts use normalized activation values, which map any range of positive activations to the
# integers from 0 to 10.
MAX_NORMALIZED_ACTIVATION = 10
VALID_ACTIVATION_TOKENS_ORDERED = list(str(i) for i in range(MAX_NORMALIZED_ACTIVATION + 1))
VALID_ACTIVATION_TOKENS = set(VALID_ACTIVATION_TOKENS_ORDERED)

# Marker appended to every line of the logprob-free prompt format so that a completion can be split
# into per-token lines on "<marker>\n".
_LOGPROB_FREE_LINE_END = "༗"


class SimulationType(str, Enum):
    """How to simulate neuron activations. Values correspond to subclasses of NeuronSimulator."""

    ALL_AT_ONCE = "all_at_once"
    """
    Use a single prompt with <unknown> tokens; calculate EVs using logprobs.

    Implemented by ExplanationNeuronSimulator.
    """

    ONE_AT_A_TIME = "one_at_a_time"
    """
    Use a separate prompt for each token being simulated; calculate EVs using logprobs.

    Implemented by ExplanationTokenByTokenSimulator.
    """

    @classmethod
    def from_string(cls, s: str) -> SimulationType:
        """Return the member whose value equals `s`; raise ValueError if there is none."""
        for simulation_type in SimulationType:
            if simulation_type.value == s:
                return simulation_type
        raise ValueError(f"Invalid simulation type: {s}")


def compute_expected_value(
    norm_probabilities_by_distribution_value: OrderedDict[int, float]
) -> float:
    """
    Given a map from distribution values (integers on the range [0, 10]) to normalized
    probabilities, return an expected value for the distribution.

    An empty map yields 0.0 (the dot product of two empty vectors).
    """
    # Convert to a builtin float so the result matches the annotation rather than leaking a numpy
    # scalar to callers.
    return float(
        np.dot(
            np.array(list(norm_probabilities_by_distribution_value.keys())),
            np.array(list(norm_probabilities_by_distribution_value.values())),
        )
    )


def parse_top_logprobs(top_logprobs: dict[str, float]) -> OrderedDict[int, float]:
    """
    Given a map from tokens to logprobs, return a map from distribution values (integers on the
    range [0, 10]) to unnormalized probabilities (in the sense that they may not sum to 1).

    Tokens that are not one of VALID_ACTIVATION_TOKENS are ignored.
    """
    probabilities_by_distribution_value = OrderedDict()
    for token, logprob in top_logprobs.items():
        if token in VALID_ACTIVATION_TOKENS:
            token_as_int = int(token)
            probabilities_by_distribution_value[token_as_int] = np.exp(logprob)
    return probabilities_by_distribution_value


def compute_predicted_activation_stats_for_token(
    top_logprobs: dict[str, float],
) -> tuple[OrderedDict[int, float], float]:
    """
    Turn the top logprobs for one position into a normalized distribution over activation values.

    Returns:
        A pair of (map from activation value to probability, renormalized over the valid activation
        tokens present in `top_logprobs`; expected activation under that distribution).
    """
    probabilities_by_distribution_value = parse_top_logprobs(top_logprobs)
    total_p_of_distribution_values = sum(probabilities_by_distribution_value.values())
    norm_probabilities_by_distribution_value = OrderedDict(
        {
            distribution_value: p / total_p_of_distribution_values
            for distribution_value, p in probabilities_by_distribution_value.items()
        }
    )
    expected_value = compute_expected_value(norm_probabilities_by_distribution_value)
    return (
        norm_probabilities_by_distribution_value,
        expected_value,
    )


# Adapted from tether/tether/core/encoder.py.
def convert_to_byte_array(s: str) -> bytearray:
    """
    Decode a token of the form "bytes:..." into the raw bytes it represents.

    After the "bytes:" prefix, each "\\xNN" group is read as one hex-encoded byte and every other
    character contributes its code point as a single byte.
    """
    byte_array = bytearray()
    assert s.startswith("bytes:"), s
    s = s[6:]
    while len(s) > 0:
        if s[0] == "\\":
            # Hex encoding.
            assert s[1] == "x"
            assert len(s) >= 4
            byte_array.append(int(s[2:4], 16))
            s = s[4:]
        else:
            # Regular ascii encoding.
            byte_array.append(ord(s[0]))
            s = s[1:]
    return byte_array


def handle_byte_encoding(
    response_tokens: Sequence[str], merged_response_index: int
) -> tuple[str, int]:
    """
    Handle the case where the current token is a sequence of bytes. This may involve merging
    multiple response tokens into a single token.

    Returns:
        The (possibly merged and decoded) token text, and the index of the earliest response token
        that was consumed to produce it. Tokens that do not start with "bytes:" are returned
        unchanged together with the index that was passed in.
    """
    response_token = response_tokens[merged_response_index]
    if response_token.startswith("bytes:"):
        byte_array = bytearray()
        while True:
            # Earlier tokens are prepended because we are walking backward through the response.
            byte_array = convert_to_byte_array(response_token) + byte_array
            try:
                # If we can decode the byte array as utf-8, then we're done.
                response_token = byte_array.decode("utf-8")
                break
            except UnicodeDecodeError:
                # If not, then we need to merge the previous response token into the byte
                # array.
                merged_response_index -= 1
                response_token = response_tokens[merged_response_index]
    return response_token, merged_response_index


def was_token_split(current_token: str, response_tokens: Sequence[str], start_index: int) -> bool:
    """
    Return whether current_token (a token from the subject model) was split into multiple tokens by
    the simulator model (as represented by the tokens in response_tokens). start_index is the index
    in response_tokens at which to begin looking backward to form a complete token. It is usually
    the first token *before* the delimiter that separates the token from the normalized activation,
    barring some unusual cases.

    This mainly happens if the subject model uses a different tokenizer than the simulator model.
    But it can also happen in cases where Unicode characters are split. This function handles both
    cases.
    """
    merged_response_tokens = ""
    merged_response_index = start_index
    while len(merged_response_tokens) < len(current_token):
        response_token, merged_response_index = handle_byte_encoding(
            response_tokens, merged_response_index
        )
        merged_response_tokens = response_token + merged_response_tokens
        merged_response_index -= 1
    # It's possible that merged_response_tokens is longer than current_token at this point,
    # since the between-lines delimiter may have been merged into the original token. But it
    # should always be the case that merged_response_tokens ends with current_token.
    assert merged_response_tokens.endswith(current_token)
    num_merged_tokens = start_index - merged_response_index
    token_was_split = num_merged_tokens > 1
    if token_was_split:
        logger.debug(
            "Warning: token from the subject model was split into 2+ tokens by the simulator model."
        )
    return token_was_split


def parse_simulation_response(
    response: dict[str, Any],
    prompt_format: PromptFormat,
    tokens: Sequence[str],
) -> SequenceSimulation:
    """
    Parse an API response to a simulation prompt.

    Args:
        response: response from the API
        prompt_format: how the prompt was formatted
        tokens: list of tokens as strings in the sequence where the neuron is being simulated

    Raises:
        ValueError: if `prompt_format` is not one of the formats handled here.
    """
    choice = response["choices"][0]
    if prompt_format == PromptFormat.HARMONY_V4:
        text = choice["message"]["content"]
    elif prompt_format in [
        PromptFormat.NONE,
        PromptFormat.INSTRUCTION_FOLLOWING,
    ]:
        text = choice["text"]
    else:
        raise ValueError(f"Unhandled prompt format {prompt_format}")
    response_tokens = choice["logprobs"]["tokens"]
    top_logprobs = choice["logprobs"]["top_logprobs"]
    token_text_offset = choice["logprobs"]["text_offset"]
    # This only works because the sequence "<start>" tokenizes into multiple tokens if it appears in
    # a text sequence in the prompt.
    scoring_start = text.rfind("<start>")
    expected_values = []
    original_sequence_tokens: list[str] = []
    distribution_values: list[list[float]] = []
    distribution_probabilities: list[list[float]] = []
    for i in range(2, len(response_tokens)):
        if len(original_sequence_tokens) == len(tokens):
            # Make sure we haven't hit some sort of off-by-one error.
            # TODO(sbills): Generalize this to handle different tokenizers.
            reached_end = response_tokens[i + 1] == "<" and response_tokens[i + 2] == "end"
            assert reached_end, f"{response_tokens[i-3:i+3]}"
            break
        if token_text_offset[i] >= scoring_start:
            # We're looking for the first token after a tab. This token should be the text
            # "unknown" if hide_activations=True or a normalized activation (0-10) otherwise.
            # If it isn't, that means that the tab is not appearing as a delimiter, but rather
            # as a token, in which case we should move on to the next response token.
            if response_tokens[i - 1] == "\t":
                if response_tokens[i] != "unknown":
                    logger.debug("Ignoring tab token that is not followed by an 'unknown' token.")
                    continue

                # j represents the index of the token in a "token<tab>activation" line, barring
                # one of the unusual cases handled below.
                j = i - 2

                current_token = tokens[len(original_sequence_tokens)]
                if current_token == response_tokens[j] or was_token_split(
                    current_token, response_tokens, j
                ):
                    # We're in the normal case where the tokenization didn't throw off the
                    # formatting or in the token-was-split case, which we handle the usual way.
                    current_top_logprobs = top_logprobs[i]

                    (
                        norm_probabilities_by_distribution_value,
                        expected_value,
                    ) = compute_predicted_activation_stats_for_token(
                        current_top_logprobs,
                    )
                    current_distribution_values = list(
                        norm_probabilities_by_distribution_value.keys()
                    )
                    current_distribution_probabilities = list(
                        norm_probabilities_by_distribution_value.values()
                    )
                else:
                    # We're in a case where the tokenization resulted in a newline being folded into
                    # the token. We can't do our usual prediction of activation stats for the token,
                    # since the model did not observe the original token. Instead, we use dummy
                    # values. See the TODO elsewhere in this file about coming up with a better
                    # prompt format that avoids this situation.
                    newline_folded_into_token = "\n" in response_tokens[j]
                    assert (
                        newline_folded_into_token
                    ), f"`{current_token=}` {response_tokens[j-3:j+3]=}"
                    logger.debug(
                        "Warning: newline before a token<tab>activation line was folded into the token"
                    )
                    current_distribution_values = []
                    current_distribution_probabilities = []
                    expected_value = 0.0
                original_sequence_tokens.append(current_token)
                distribution_values.append([float(v) for v in current_distribution_values])
                distribution_probabilities.append(current_distribution_probabilities)
                expected_values.append(expected_value)

    return SequenceSimulation(
        tokens=original_sequence_tokens,
        expected_activations=expected_values,
        activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
        distribution_values=distribution_values,
        distribution_probabilities=distribution_probabilities,
    )


class NeuronSimulator(ABC):
    """Abstract base class for simulating neuron behavior."""

    @abstractmethod
    async def simulate(self, tokens: Sequence[str]) -> SequenceSimulation:
        """Simulate the behavior of a neuron based on an explanation."""
        ...


class ExplanationNeuronSimulator(NeuronSimulator):
    """
    Simulate neuron behavior based on an explanation.

    This class uses a few-shot prompt with examples of other explanations and activations. This
    prompt allows us to score all of the tokens at once using a nifty trick involving logprobs.
    """

    def __init__(
        self,
        model_name: str,
        explanation: str,
        max_concurrent: Optional[int] = 10,
        few_shot_example_set: FewShotExampleSet = FewShotExampleSet.ORIGINAL,
        prompt_format: PromptFormat = PromptFormat.INSTRUCTION_FOLLOWING,
        cache: bool = False,
    ):
        self.api_client = ApiClient(
            model_name=model_name, max_concurrent=max_concurrent, cache=cache
        )
        self.explanation = explanation
        self.few_shot_example_set = few_shot_example_set
        self.prompt_format = prompt_format

    async def simulate(
        self,
        tokens: Sequence[str],
    ) -> SequenceSimulation:
        """Simulate activations for all of `tokens` with a single echoed, zero-completion request."""
        prompt = self.make_simulation_prompt(tokens)

        # max_tokens=0 with echo=True: nothing new is generated; the response that gets parsed
        # below carries logprobs for the prompt tokens themselves.
        generate_kwargs: dict[str, Any] = {
            "max_tokens": 0,
            "echo": True,
            "logprobs": 15,
        }
        if self.prompt_format == PromptFormat.HARMONY_V4:
            assert isinstance(prompt, list)
            assert isinstance(prompt[0], dict)  # Really a HarmonyMessage
            generate_kwargs["messages"] = prompt
        else:
            assert isinstance(prompt, str)
            generate_kwargs["prompt"] = prompt

        response = await self.api_client.make_request(**generate_kwargs)
        logger.debug("response in score_explanation_by_activations is %s", response)
        result = parse_simulation_response(response, self.prompt_format, tokens)
        logger.debug("result in score_explanation_by_activations is %s", result)
        return result

    # TODO(sbills): The current token<tab>activation format can result in improper tokenization.
    # In particular, if the token is itself a tab, we may get a single "\t\t" token rather than two
    # "\t" tokens. Consider using a separator that does not appear in any multi-character tokens.
    def make_simulation_prompt(self, tokens: Sequence[str]) -> Union[str, list[HarmonyMessage]]:
        """Create a few-shot prompt for predicting neuron activations for the given tokens."""

        # TODO(sbills): The prompts in this file are subtly different from the ones in explainer.py.
        # Consider reconciling them.
        prompt_builder = PromptBuilder()
        prompt_builder.add_message(
            Role.SYSTEM,
            """We're studying neurons in a neural network.
Each neuron looks for some particular thing in a short document.
Look at summary of what the neuron does, and try to predict how it will fire on each token.

The activation format is token<tab>activation, activations go from 0 to 10, "unknown" indicates an unknown activation. Most activations will be 0.
""",
        )

        few_shot_examples = self.few_shot_example_set.get_examples()
        for i, example in enumerate(few_shot_examples):
            prompt_builder.add_message(
                Role.USER,
                f"\n\nNeuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
                f"{example.explanation}",
            )
            formatted_activation_records = format_activation_records(
                example.activation_records,
                calculate_max_activation(example.activation_records),
                start_indices=example.first_revealed_activation_indices,
            )
            prompt_builder.add_message(
                Role.ASSISTANT, f"\nActivations: {formatted_activation_records}\n"
            )

        prompt_builder.add_message(
            Role.USER,
            f"\n\nNeuron {len(few_shot_examples) + 1}\nExplanation of neuron "
            f"{len(few_shot_examples) + 1} behavior: {EXPLANATION_PREFIX} "
            f"{self.explanation.strip()}",
        )
        prompt_builder.add_message(
            Role.ASSISTANT, f"\nActivations: {format_sequences_for_simulation([tokens])}"
        )
        return prompt_builder.build(self.prompt_format)


class ExplanationTokenByTokenSimulator(NeuronSimulator):
    """
    Simulate neuron behavior based on an explanation.

    Unlike ExplanationNeuronSimulator, this class uses one few-shot prompt per token to calculate
    expected activations. This is slower. This class gets a one-token completion and calculates an
    expected value from that token's logprobs.
    """

    def __init__(
        self,
        model_name: str,
        explanation: str,
        max_concurrent: Optional[int] = 10,
        few_shot_example_set: FewShotExampleSet = FewShotExampleSet.NEWER,
        prompt_format: PromptFormat = PromptFormat.INSTRUCTION_FOLLOWING,
        cache: bool = False,
    ):
        assert (
            few_shot_example_set != FewShotExampleSet.ORIGINAL
        ), "This simulator doesn't support the ORIGINAL few-shot example set."
        self.api_client = ApiClient(
            model_name=model_name, max_concurrent=max_concurrent, cache=cache
        )
        self.explanation = explanation
        self.few_shot_example_set = few_shot_example_set
        self.prompt_format = prompt_format

    async def simulate(
        self,
        tokens: Sequence[str],
    ) -> SequenceSimulation:
        """Simulate activations by issuing one request per token and reading its top logprobs."""
        responses_by_token = await asyncio.gather(
            *[
                self._get_activation_stats_for_single_token(tokens, self.explanation, token_index)
                for token_index in range(len(tokens))
            ]
        )
        expected_values, distribution_values, distribution_probabilities = [], [], []
        for response in responses_by_token:
            # Only one token is sampled per request, so position 0 holds its candidate logprobs.
            activation_logprobs = response["choices"][0]["logprobs"]["top_logprobs"][0]
            (
                norm_probabilities_by_distribution_value,
                expected_value,
            ) = compute_predicted_activation_stats_for_token(
                activation_logprobs,
            )
            distribution_values.append(
                [float(v) for v in norm_probabilities_by_distribution_value.keys()]
            )
            distribution_probabilities.append(
                list(norm_probabilities_by_distribution_value.values())
            )
            expected_values.append(expected_value)

        result = SequenceSimulation(
            tokens=list(tokens),  # SequenceSimulation expects List type
            expected_activations=expected_values,
            activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
            distribution_values=distribution_values,
            distribution_probabilities=distribution_probabilities,
        )
        logger.debug("result in score_explanation_by_activations is %s", result)
        return result

    async def _get_activation_stats_for_single_token(
        self,
        tokens: Sequence[str],
        explanation: str,
        token_index_to_score: int,
    ) -> dict:
        """Request a one-token completion (with logprobs) for the token at `token_index_to_score`."""
        prompt = self.make_single_token_simulation_prompt(
            tokens,
            explanation,
            token_index_to_score=token_index_to_score,
        )
        # NOTE(review): the prompt is always sent under the `prompt` key and simulate() reads the
        # text-completion style `logprobs.top_logprobs` field, so this path presumably only works
        # with string prompt formats - confirm before using PromptFormat.HARMONY_V4 here.
        return await self.api_client.make_request(
            prompt=prompt, max_tokens=1, echo=False, logprobs=15
        )

    def _add_single_token_simulation_subprompt(
        self,
        prompt_builder: PromptBuilder,
        activation_record: ActivationRecord,
        neuron_index: int,
        explanation: str,
        token_index_to_score: int,
        end_of_prompt: bool,
    ) -> None:
        """
        Append one "predict the last token's activation" exchange to `prompt_builder`.

        The record is cut off after `token_index_to_score`. When `end_of_prompt` is False the
        answer (the normalized activation of that token) is appended as an assistant message; when
        it is True the prompt is left open for the model to answer.
        """
        trimmed_activation_record = ActivationRecord(
            tokens=activation_record.tokens[: token_index_to_score + 1],
            activations=activation_record.activations[: token_index_to_score + 1],
        )
        prompt_builder.add_message(
            Role.USER,
            f"""
Neuron {neuron_index}
Explanation of neuron {neuron_index} behavior: {EXPLANATION_PREFIX} {explanation.strip()}

Text:
{"".join(trimmed_activation_record.tokens)}

Last token in the text:
{trimmed_activation_record.tokens[-1]}

Last token activation, considering the token in the context in which it appeared in the text:
""",
        )
        if not end_of_prompt:
            # Normalize against the max over the *untrimmed* record so the scale matches the full
            # sequence rather than just the visible prefix.
            normalized_activations = normalize_activations(
                trimmed_activation_record.activations, calculate_max_activation([activation_record])
            )
            prompt_builder.add_message(Role.ASSISTANT, str(normalized_activations[-1]) + "\n\n")

    def make_single_token_simulation_prompt(
        self,
        tokens: Sequence[str],
        explanation: str,
        token_index_to_score: int,
    ) -> Union[str, list[HarmonyMessage]]:
        """Make a few-shot prompt for predicting the neuron's activation on a single token."""
        assert explanation != ""
        prompt_builder = PromptBuilder()
        prompt_builder.add_message(
            Role.SYSTEM,
            """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.

The activation format is token<tab>activation, and activations range from 0 to 10. Most activations will be 0.
""",
        )

        few_shot_examples = self.few_shot_example_set.get_examples()
        for i, example in enumerate(few_shot_examples):
            prompt_builder.add_message(
                Role.USER,
                f"Neuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
                f"{example.explanation}\n",
            )
            formatted_activation_records = format_activation_records(
                example.activation_records,
                calculate_max_activation(example.activation_records),
                start_indices=None,
            )
            prompt_builder.add_message(
                Role.ASSISTANT,
                f"Activations: {formatted_activation_records}\n\n",
            )

        prompt_builder.add_message(
            Role.SYSTEM,
            "Now, we're going predict the activation of a new neuron on a single token, "
            "following the same rules as the examples above. Activations still range from 0 to 10.",
        )
        single_token_example = self.few_shot_example_set.get_single_token_prediction_example()
        assert single_token_example.token_index_to_score is not None
        self._add_single_token_simulation_subprompt(
            prompt_builder,
            single_token_example.activation_records[0],
            len(few_shot_examples) + 1,
            explanation,
            token_index_to_score=single_token_example.token_index_to_score,
            end_of_prompt=False,
        )

        visible_tokens = list(tokens[: token_index_to_score + 1])
        activation_record = ActivationRecord(
            tokens=visible_tokens,  # ActivationRecord expects List type.
            # Placeholder activations (never shown, since end_of_prompt=True); kept the same
            # length as the tokens so the record is internally consistent.
            activations=[0.0] * len(visible_tokens),
        )
        self._add_single_token_simulation_subprompt(
            prompt_builder,
            activation_record,
            len(few_shot_examples) + 2,
            explanation,
            token_index_to_score,
            end_of_prompt=True,
        )
        return prompt_builder.build(self.prompt_format, allow_extra_system_messages=True)


def _format_record_for_logprob_free_simulation(
    activation_record: ActivationRecord,
    include_activations: bool = False,
    max_activation: Optional[float] = None,
) -> str:
    """
    Render a record as one "token<tab>[activation]<marker>" line per token.

    When `include_activations` is True, `max_activation` must be given and each line carries the
    token's normalized activation; otherwise the activation slot is left empty.
    """
    # We use a weird unicode character here to make it easier to parse the response (can split on
    # "༗\n").
    if include_activations:
        assert max_activation is not None
        assert len(activation_record.tokens) == len(
            activation_record.activations
        ), f"{len(activation_record.tokens)=}, {len(activation_record.activations)=}"
        normalized_activations = normalize_activations(
            activation_record.activations, max_activation=max_activation
        )
        lines = [
            f"{token}\t{normalized_activations[i]}༗\n"
            for i, token in enumerate(activation_record.tokens)
        ]
    else:
        lines = [f"{token}\t༗\n" for token in activation_record.tokens]
    return "".join(lines)


def _parse_no_logprobs_completion(
    completion: str,
    tokens: Sequence[str],
) -> Sequence[int]:
    """
    Parse a completion into a list of simulated activations. If the model did not faithfully
    reproduce the token sequence, return a list of 0s. If the model's activation for a token
    is not an integer between 0 and 10, substitute 0.

    Args:
        completion: completion from the API
        tokens: list of tokens as strings in the sequence where the neuron is being simulated
    """
    zero_prediction = [0] * len(tokens)
    if not tokens:
        # Nothing to predict; also avoids indexing tokens[0] below.
        return zero_prediction

    stripped_completion = completion.strip("\n")
    # Splitting on "༗\n" leaves the marker attached to the final line (its trailing newline was
    # just stripped), which would make the last activation unparseable. Drop it first.
    if stripped_completion.endswith(_LOGPROB_FREE_LINE_END):
        stripped_completion = stripped_completion[: -len(_LOGPROB_FREE_LINE_END)]
    token_lines = stripped_completion.split(_LOGPROB_FREE_LINE_END + "\n")

    start_line_index = None
    for i, token_line in enumerate(token_lines):
        if token_line.startswith(f"{tokens[0]}\t"):
            start_line_index = i
            break

    # If we didn't find the first token, or if the number of lines in the completion doesn't match
    # the number of tokens, return a list of 0s.
    if start_line_index is None or len(token_lines) - start_line_index != len(tokens):
        return zero_prediction

    predicted_activations = []
    for i, token_line in enumerate(token_lines[start_line_index:]):
        line_prefix = f"{tokens[i]}\t"
        if not token_line.startswith(line_prefix):
            return zero_prediction
        # Read the activation from what follows the known "token<tab>" prefix, so that tokens
        # which themselves contain a tab are still parsed correctly.
        predicted_activation = token_line[len(line_prefix) :].split("\t")[0]
        if predicted_activation not in VALID_ACTIVATION_TOKENS:
            predicted_activations.append(0)
        else:
            predicted_activations.append(int(predicted_activation))
    return predicted_activations


class LogprobFreeExplanationTokenSimulator(NeuronSimulator):
    """
    Simulate neuron behavior based on an explanation.

    Unlike ExplanationNeuronSimulator and ExplanationTokenByTokenSimulator, this class does not rely
    on logprobs to calculate expected activations. Instead, it uses a few-shot prompt that displays
    all of the tokens at once, and request that the model repeat the tokens with the activations
    appended. Sampling is with temperature = 0. Thus, the activations are deterministic. Also, each
    activation for a token is a function of all the activations that came previously and all of the
    tokens in the sequence, not just the current and previous tokens. In the case where the model
    does not faithfully reproduce the token sequence, the simulator will return a response where
    every predicted activation is 0.

    Example prompt as follows:

        Explanation: Explanation 1

        Sequence 1 Tokens Without Activations:

        A\t_
        B\t_
        C\t_

        Sequence 1 Tokens With Activations:

        A\t4_
        B\t10_
        C\t0_

        Sequence 2 Tokens Without Activations:

        D\t_
        E\t_
        F\t_

        Sequence 2 Tokens With Activations:

        D\t3_
        E\t6_
        F\t9_

        Explanation: Explanation 2

        Sequence 1 Tokens Without Activations:

        G\t_
        H\t_
        I\t_

        Sequence 1 Tokens With Activations:

        <start sampling here>

        G\t2_
        H\t0_
        I\t3_

    """

    def __init__(
        self,
        model_name: str,
        explanation: str,
        max_concurrent: Optional[int] = 10,
        few_shot_example_set: FewShotExampleSet = FewShotExampleSet.NEWER,
        prompt_format: PromptFormat = PromptFormat.HARMONY_V4,
        cache: bool = False,
    ):
        assert (
            few_shot_example_set != FewShotExampleSet.ORIGINAL
        ), "This simulator doesn't support the ORIGINAL few-shot example set."
        self.api_client = ApiClient(
            model_name=model_name, max_concurrent=max_concurrent, cache=cache
        )
        self.explanation = explanation
        self.few_shot_example_set = few_shot_example_set
        self.prompt_format = prompt_format

    async def simulate(
        self,
        tokens: Sequence[str],
    ) -> SequenceSimulation:
        """
        Simulate activations by asking the model to repeat `tokens` with activations filled in.

        Raises:
            ValueError: if `self.prompt_format` is not one of the formats handled here.
        """
        prompt = self._make_simulation_prompt(
            tokens,
            self.explanation,
        )

        # Same prompt/messages dispatch as ExplanationNeuronSimulator.simulate: a HARMONY_V4 prompt
        # is a list of messages, and its reply is read from choice["message"] below.
        generate_kwargs: dict[str, Any] = {"echo": False, "max_tokens": 1000}
        if self.prompt_format == PromptFormat.HARMONY_V4:
            assert isinstance(prompt, list)
            generate_kwargs["messages"] = prompt
        else:
            assert isinstance(prompt, str)
            generate_kwargs["prompt"] = prompt
        response = await self.api_client.make_request(**generate_kwargs)
        assert len(response["choices"]) == 1

        choice = response["choices"][0]
        if self.prompt_format == PromptFormat.HARMONY_V4:
            completion = choice["message"]["content"]
        elif self.prompt_format in [PromptFormat.NONE, PromptFormat.INSTRUCTION_FOLLOWING]:
            completion = choice["text"]
        else:
            raise ValueError(f"Unhandled prompt format {self.prompt_format}")

        predicted_activations = _parse_no_logprobs_completion(completion, tokens)

        result = SequenceSimulation(
            activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
            expected_activations=predicted_activations,
            # Since the predicted activation is just a sampled token, we don't have a distribution.
            distribution_values=[],
            distribution_probabilities=[],
            tokens=list(tokens),  # SequenceSimulation expects List type
        )
        logger.debug("result in score_explanation_by_activations is %s", result)
        return result

    def _make_simulation_prompt(
        self,
        tokens: Sequence[str],
        explanation: str,
    ) -> Union[str, list[HarmonyMessage]]:
        """Make a few-shot prompt for predicting the neuron's activations on a sequence."""
        assert explanation != ""
        # Constructed and built the same way as in ExplanationTokenByTokenSimulator: the
        # allow_extra_system_messages flag is passed to build(), not to the constructor.
        prompt_builder = PromptBuilder()
        prompt_builder.add_message(
            Role.SYSTEM,
            """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.

The activation format is token<tab>activation, and activations range from 0 to 10. Most activations will be 0.
For each sequence, you will see the tokens in the sequence where the activations are left blank. You will print the exact same tokens verbatim, but with the activations filled in according to the explanation.
""",
        )

        few_shot_examples = self.few_shot_example_set.get_examples()
        for i, example in enumerate(few_shot_examples):
            few_shot_example_max_activation = calculate_max_activation(example.activation_records)

            prompt_builder.add_message(
                Role.USER,
                f"Neuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
                f"{example.explanation}\n\n"
                f"Sequence 1 Tokens without Activations:\n{_format_record_for_logprob_free_simulation(example.activation_records[0], include_activations=False)}\n\n"
                f"Sequence 1 Tokens with Activations:\n",
            )
            prompt_builder.add_message(
                Role.ASSISTANT,
                f"{_format_record_for_logprob_free_simulation(example.activation_records[0], include_activations=True, max_activation=few_shot_example_max_activation)}\n\n",
            )

            for record_index, record in enumerate(example.activation_records[1:]):
                prompt_builder.add_message(
                    Role.USER,
                    f"Sequence {record_index + 2} Tokens without Activations:\n{_format_record_for_logprob_free_simulation(record, include_activations=False)}\n\n"
                    f"Sequence {record_index + 2} Tokens with Activations:\n",
                )
                prompt_builder.add_message(
                    Role.ASSISTANT,
                    f"{_format_record_for_logprob_free_simulation(record, include_activations=True, max_activation=few_shot_example_max_activation)}\n\n",
                )

        neuron_index = len(few_shot_examples) + 1
        # Activations are unknown for the sequence being simulated, so the record carries none.
        blank_record = ActivationRecord(tokens=list(tokens), activations=[])
        prompt_builder.add_message(
            Role.USER,
            f"Neuron {neuron_index}\nExplanation of neuron {neuron_index} behavior: {EXPLANATION_PREFIX} "
            f"{explanation}\n\n"
            f"Sequence 1 Tokens without Activations:\n{_format_record_for_logprob_free_simulation(blank_record, include_activations=False)}\n\n"
            f"Sequence 1 Tokens with Activations:\n",
        )
        return prompt_builder.build(self.prompt_format, allow_extra_system_messages=True)