neuron-explainer/neuron_explainer/activations/activation_records.py

"""Utilities for formatting activation records into prompts.""" import math from typing import Optional, Sequence from neuron_explainer.activations.activations import ActivationRecord UNKNOWN_ACTIVATION_STRING = "unknown" def relu(x: float) -> float: return max(0.0, x) def calculate_max_activation(activation_records: Sequence[ActivationRecord]) -> float: """Return the maximum activation value of the neuron across all the activation records.""" flattened = [ # Relu is used to assume any values less than 0 are indicating the neuron is in the resting # state. This is a simplifying assumption that works with relu/gelu. max(relu(x) for x in activation_record.activations) for activation_record in activation_records ] return max(flattened) def normalize_activations(activation_record: list[float], max_activation: float) -> list[int]: """Convert raw neuron activations to integers on the range [0, 10].""" if max_activation <= 0: return [0 for x in activation_record] # Relu is used to assume any values less than 0 are indicating the neuron is in the resting # state. This is a simplifying assumption that works with relu/gelu. return [min(10, math.floor(10 * relu(x) / max_activation)) for x in activation_record] def _format_activation_record( activation_record: ActivationRecord, max_activation: float, omit_zeros: bool, hide_activations: bool = False, start_index: int = 0, ) -> str: """Format neuron activations into a string, suitable for use in prompts.""" tokens = activation_record.tokens normalized_activations = normalize_activations(activation_record.activations, max_activation) if omit_zeros: assert (not hide_activations) and start_index == 0, "Can't hide activations and omit zeros" tokens = [ token for token, activation in zip(tokens, normalized_activations) if activation > 0 ] normalized_activations = [x for x in normalized_activations if x > 0] entries = [] assert len(tokens) == len(normalized_activations) for index, token, activation in zip(range(len(tokens)), tokens, normalized_activations): activation_string = str(int(activation)) if hide_activations or index < start_index: activation_string = UNKNOWN_ACTIVATION_STRING entries.append(f"{token}\t{activation_string}") return "\n".join(entries) def format_activation_records( activation_records: Sequence[ActivationRecord], max_activation: float, *, omit_zeros: bool = False, start_indices: Optional[list[int]] = None, hide_activations: bool = False, ) -> str: """Format a list of activation records into a string.""" return ( "\n<start>\n" + "\n<end>\n<start>\n".join( [ _format_activation_record( activation_record, max_activation, omit_zeros=omit_zeros, hide_activations=hide_activations, start_index=0 if start_indices is None else start_indices[i], ) for i, activation_record in enumerate(activation_records) ] ) + "\n<end>\n" ) def _format_tokens_for_simulation(tokens: Sequence[str]) -> str: """ Format tokens into a string with each token marked as having an "unknown" activation, suitable for use in prompts. """ entries = [] for token in tokens: entries.append(f"{token}\t{UNKNOWN_ACTIVATION_STRING}") return "\n".join(entries) def format_sequences_for_simulation( all_tokens: Sequence[Sequence[str]], ) -> str: """ Format a list of lists of tokens into a string with each token marked as having an "unknown" activation, suitable for use in prompts. 
""" return ( "\n<start>\n" + "\n<end>\n<start>\n".join( [_format_tokens_for_simulation(tokens) for tokens in all_tokens] ) + "\n<end>\n" ) def non_zero_activation_proportion( activation_records: Sequence[ActivationRecord], max_activation: float ) -> float: """Return the proportion of activation values that aren't zero.""" total_activations_count = sum( [len(activation_record.activations) for activation_record in activation_records] ) normalized_activations = [ normalize_activations(activation_record.activations, max_activation) for activation_record in activation_records ] non_zero_activations_count = sum( [len([x for x in activations if x != 0]) for activations in normalized_activations] ) return non_zero_activations_count / total_activations_count