# neuron_explainer/activations/activation_records.py
"""Utilities for formatting activation records into prompts."""
import math
from typing import Any, Sequence
from neuron_explainer.activations.activations import ActivationRecord
UNKNOWN_ACTIVATION_STRING = "unknown"
def relu(x: float) -> float:
return max(0.0, x)
def calculate_max_activation(activation_records: Sequence[ActivationRecord]) -> float:
"""Return the maximum activation value of the neuron across all the activation records."""
flattened = [
        # ReLU is applied because negative values are taken to indicate that the neuron
        # is in its resting state.
max(relu(x) for x in activation_record.activations)
for activation_record in activation_records
]
return max(flattened)
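# Usage sketch (illustrative only; assumes ActivationRecord accepts `tokens` and
# `activations` keyword arguments, matching the constructor calls later in this file):
#   >>> records = [
#   ...     ActivationRecord(tokens=["The", " cat"], activations=[-0.5, 2.0]),
#   ...     ActivationRecord(tokens=["Hi"], activations=[1.25]),
#   ... ]
#   >>> calculate_max_activation(records)
#   2.0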
def normalize_activations(activation_record: list[float], max_activation: float) -> list[int]:
"""Convert raw neuron activations to integers on the range [0, 10]."""
    if max_activation <= 0:
        # A non-positive max activation means the neuron never fired on this dataset
        # (which happens for some GPT-2 datasets), so render every activation as 0
        # rather than raising a ValueError.
        return [0] * len(activation_record)
    # ReLU is applied because negative values are taken to indicate that the neuron is
    # in its resting state.
return [min(10, math.floor(10 * relu(x) / max_activation)) for x in activation_record]
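# Worked example: each value maps through floor(10 * relu(x) / max_activation),
# capped at 10; a non-positive max_activation yields all zeros.
#   >>> normalize_activations([-1.0, 0.5, 2.0], max_activation=2.0)
#   [0, 2, 10]
#   >>> normalize_activations([1.0], max_activation=0.0)
#   [0]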
def normalize_activations_symmetric(
activation_record: list[float], max_activation: float
) -> list[int]:
"""Convert raw neuron activations to integers on the range [-10, 10]."""
    # The parameter is named max_activation to match the keyword that callers already
    # pass, but for symmetric normalization it is interpreted as the maximum *absolute*
    # activation value.
    max_abs_activation = max_activation
    if max_abs_activation == 0.0:
        # A zero max activation means the neuron never fired, so render every
        # activation as 0 rather than raising a ValueError.
        return [0] * len(activation_record)
# Unlike normalize_activations, this function doesn't apply relu, since we want to show negative
# activations as well.
return [max(min(10, math.trunc(10 * x / max_abs_activation)), -10) for x in activation_record]
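# Worked example: trunc(10 * x / max_activation), clamped to [-10, 10]; note that
# negative activations survive here, unlike in normalize_activations.
#   >>> normalize_activations_symmetric([-2.0, -0.5, 1.0], max_activation=2.0)
#   [-10, -2, 5]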
def truncate_negative_activations(activation_record: ActivationRecord) -> ActivationRecord:
"""Truncate activations to 0 if they are negative."""
return ActivationRecord(
tokens=activation_record.tokens,
activations=[max(0, x) for x in activation_record.activations],
)
def truncate_negative_activations_list(
activation_records: Sequence[ActivationRecord],
) -> list[ActivationRecord]:
"""Truncate activations to 0 if they are negative."""
return [truncate_negative_activations(x) for x in activation_records]
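# Usage sketch: tokens pass through unchanged; only negative activations are clamped.
#   >>> rec = ActivationRecord(tokens=["a", "b"], activations=[-0.5, 1.5])
#   >>> truncate_negative_activations(rec).activations
#   [0, 1.5]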
def _format_activation_record(
activation_record: ActivationRecord,
max_activation: float,
omit_zeros: bool,
hide_activations: bool = False,
start_index: int = 0,
) -> str:
"""Format neuron activations into a string, suitable for use in prompts."""
tokens = activation_record.tokens
normalized_activations = normalize_activations(activation_record.activations, max_activation)
if omit_zeros:
        assert (
            not hide_activations and start_index == 0
        ), "Can't omit zeros while hiding activations or using a nonzero start_index"
tokens = [
token for token, activation in zip(tokens, normalized_activations) if activation > 0
]
normalized_activations = [x for x in normalized_activations if x > 0]
entries = []
assert len(tokens) == len(normalized_activations)
    for index, (token, activation) in enumerate(zip(tokens, normalized_activations)):
        activation_string = str(activation)
if hide_activations or index < start_index:
activation_string = UNKNOWN_ACTIVATION_STRING
entries.append(f"{token}\t{activation_string}")
return "\n".join(entries)
def format_activation_records(
activation_records: Sequence[ActivationRecord],
max_activation: float,
*,
omit_zeros: bool = False,
start_indices: list[int] | None = None,
hide_activations: bool = False,
) -> str:
"""Format a list of activation records into a string."""
return (
"\n<start>\n"
+ "\n<end>\n<start>\n".join(
[
_format_activation_record(
activation_record,
max_activation,
omit_zeros=omit_zeros,
hide_activations=hide_activations,
start_index=0 if start_indices is None else start_indices[i],
)
for i, activation_record in enumerate(activation_records)
]
)
+ "\n<end>\n"
)
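# Usage sketch of the <start>/<end> framing (single record, shown as a repr):
#   >>> rec = ActivationRecord(tokens=["The", " cat"], activations=[0.0, 2.0])
#   >>> format_activation_records([rec], max_activation=2.0)
#   '\n<start>\nThe\t0\n cat\t10\n<end>\n'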
def _format_tokens_for_simulation(tokens: Sequence[str]) -> str:
"""
Format tokens into a string with each token marked as having an "unknown" activation, suitable
for use in prompts.
"""
entries = []
for token in tokens:
entries.append(f"{token}\t{UNKNOWN_ACTIVATION_STRING}")
return "\n".join(entries)
def format_sequences_for_simulation(
all_tokens: Sequence[Sequence[str]],
) -> str:
"""
Format a list of lists of tokens into a string with each token marked as having an "unknown"
activation, suitable for use in prompts.
"""
return (
"\n<start>\n"
+ "\n<end>\n<start>\n".join(
[_format_tokens_for_simulation(tokens) for tokens in all_tokens]
)
+ "\n<end>\n"
)
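# Usage sketch: every token is paired with "unknown", using the same <start>/<end>
# framing as format_activation_records:
#   >>> format_sequences_for_simulation([["a", "b"], ["c"]])
#   '\n<start>\na\tunknown\nb\tunknown\n<end>\n<start>\nc\tunknown\n<end>\n'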
def non_zero_activation_proportion(
activation_records: Sequence[ActivationRecord], max_activation: float
) -> float:
"""Return the proportion of activation values that aren't zero."""
    total_activations_count = sum(
        len(activation_record.activations) for activation_record in activation_records
    )
normalized_activations = [
normalize_activations(activation_record.activations, max_activation)
for activation_record in activation_records
]
    non_zero_activations_count = sum(
        len([x for x in activations if x != 0]) for activations in normalized_activations
    )
return non_zero_activations_count / total_activations_count
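# Worked example: normalized activations are [0, 10] and [6], so 2 of the 3 values
# are non-zero.
#   >>> recs = [
#   ...     ActivationRecord(tokens=["a", "b"], activations=[-0.5, 2.0]),
#   ...     ActivationRecord(tokens=["c"], activations=[1.25]),
#   ... ]
#   >>> round(non_zero_activation_proportion(recs, max_activation=2.0), 3)
#   0.667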
def get_attribute_or_key(activation_record: Any, attribute_name: str) -> Any:
    """Return the named field from activation_record, whether it's a dict or an object."""
if isinstance(activation_record, dict):
assert attribute_name in activation_record, f"{attribute_name} not in activation_record"
return activation_record[attribute_name]
else:
assert hasattr(activation_record, attribute_name)
return getattr(activation_record, attribute_name)
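# Usage sketch: works for both mapping-style and attribute-style records.
#   >>> get_attribute_or_key({"tokens": ["a"]}, "tokens")
#   ['a']
#   >>> get_attribute_or_key(ActivationRecord(tokens=["a"], activations=[0.0]), "tokens")
#   ['a']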