neuron_explainer/activation_server/neuron_datasets.py (101 lines of code) (raw):
"""
Library for registering neuron datasets and looking up their metadata by short name and derived
scalar type.
"""
from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable
# Metadata describing one collated neuron-activation dataset. Instances are created by
# register_neuron_dataset_metadata and stored in NEURON_DATASET_METADATA_REGISTRY under the key
# (short_name, DerivedScalarType(derived_scalar_type)).
# NOTE(review): @immutable comes from neuron_explainer.pydantic; presumably it makes instances
# frozen — confirm there.
@immutable
class NeuronDatasetMetadata(CamelCaseBaseModel):
    short_name: str
    """Short name for the dataset, like "gpt2-small"."""
    # Holds the raw string value passed at registration (e.g. "mlp_post_act"), not a
    # DerivedScalarType member; convert with DerivedScalarType(...) when an enum is needed.
    derived_scalar_type: str
    """The type of scalar that the neuron records in this dataset contain, e.g. DerivedScalarType.MLP_POST_ACT."""
    user_visible_name: str
    """Name for humans to read, like "GPT-2 small"."""
    # For the registered datasets in this module, this is an HTTPS URL to a blob-storage prefix.
    neuron_dataset_path: str
    """Path to the neuron dataset generated by collate_activations."""
    # Take care when adding new fields to this class. If they aren't optional, existing metadata
    # files will cause errors when they're read. You can make them required after ensuring that all
    # metadata files have been updated.
# Registry of known datasets: maps (short_name, DerivedScalarType) to the metadata registered under
# that key. Populated by register_neuron_dataset_metadata; a dict, so it keeps insertion order.
NEURON_DATASET_METADATA_REGISTRY: dict[tuple[str, DerivedScalarType], NeuronDatasetMetadata] = {}
def register_neuron_dataset_metadata(
    short_name: str, derived_scalar_type: str, user_visible_name: str, neuron_dataset_path: str
) -> None:
    """Record a dataset's metadata in NEURON_DATASET_METADATA_REGISTRY.

    The entry is keyed by (short_name, DerivedScalarType(derived_scalar_type)), so the same short
    name may be registered once per derived scalar type. Registering an existing key again
    replaces the previous entry.

    Raises:
        Whatever DerivedScalarType(derived_scalar_type) raises for an unknown value
        (ValueError for a standard Enum).
    """
    metadata = NeuronDatasetMetadata(
        short_name=short_name,
        derived_scalar_type=derived_scalar_type,
        user_visible_name=user_visible_name,
        neuron_dataset_path=neuron_dataset_path,
    )
    # Key on the enum member, not the raw string, so lookups by DerivedScalarType find this entry.
    registry_key = (short_name, DerivedScalarType(derived_scalar_type))
    NEURON_DATASET_METADATA_REGISTRY[registry_key] = metadata
# Built-in datasets, registered in the order listed (the registry preserves insertion order, which
# is the order get_all_neuron_dataset_metadata returns them in).
for _dataset_kwargs in (
    dict(
        short_name="gpt2-xl",
        derived_scalar_type="mlp_post_act",
        user_visible_name="GPT-2 XL - MLP neurons",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/data/collated-activations/",
    ),
    dict(
        short_name="gpt2-small",
        derived_scalar_type="mlp_post_act",
        user_visible_name="GPT-2 small - MLP neurons",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small_data/collated-activations/",
    ),
    dict(
        short_name="gpt2-small",
        derived_scalar_type="attn_write_norm",
        user_visible_name="GPT-2 small - Attention write by token pair",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/attn_write_norm/collated-activations-by-token-pair",
    ),
    dict(
        short_name="gpt2-small_ae-mlp-post-act-v1",
        derived_scalar_type="mlp_autoencoder_latent",
        user_visible_name="GPT-2 small - MLP autoencoder (post-act) v1",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/mlp_post_act_v1/collated-activations",
    ),
    dict(
        short_name="gpt2-small_ae-resid-delta-mlp-v1",
        derived_scalar_type="mlp_autoencoder_latent",
        user_visible_name="GPT-2 small - MLP autoencoder (write) v1",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_mlp_v1/collated-activations",
    ),
    dict(
        short_name="gpt2-small_ae-mlp-post-act-v4",
        derived_scalar_type="mlp_autoencoder_latent",
        user_visible_name="GPT-2 small - MLP autoencoder (post-act) v4",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/mlp_post_act_v4/collated-activations",
    ),
    dict(
        short_name="gpt2-small_ae-resid-delta-mlp-v4",
        derived_scalar_type="mlp_autoencoder_latent",
        user_visible_name="GPT-2 small - MLP autoencoder (write) v4",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_mlp_v4/collated-activations",
    ),
    dict(
        short_name="gpt2-small_ae-resid-delta-attn-v4",
        derived_scalar_type="attention_autoencoder_latent",
        user_visible_name="GPT-2 small - Attention autoencoder (write) v4",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_attn_v4/collated-activations",
    ),
    dict(
        short_name="gpt2-small_ae-resid-delta-attn-v4",
        derived_scalar_type="flattened_attn_write_to_latent_summed_over_heads",
        user_visible_name="GPT-2 small - Attention autoencoder (write) v4 by token pair",
        neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_attn_v4/collated-activations-by-token-pair",
    ),
):
    register_neuron_dataset_metadata(**_dataset_kwargs)
# Drop the loop variable so the module namespace only exposes the registry and functions.
del _dataset_kwargs
def get_all_neuron_dataset_metadata() -> list[NeuronDatasetMetadata]:
    """Return a new list of every registered dataset's metadata.

    Entries come out in the order their keys were first added to the registry.
    """
    return [*NEURON_DATASET_METADATA_REGISTRY.values()]
def _format_dataset_key(short_name: str, dst: DerivedScalarType) -> str:
    """Render a registry key as ("short_name", "dst_value") for use in error messages."""
    # Use the enum's value rather than str()/format() of the member, so the output matches the
    # string used at registration; fall back to the object itself if a plain string was passed.
    dst_text = dst.value if isinstance(dst, DerivedScalarType) else dst
    return f'("{short_name}", "{dst_text}")'


def get_neuron_dataset_metadata_by_short_name_and_dst(
    short_name: str, dst: DerivedScalarType
) -> NeuronDatasetMetadata:
    """Return the metadata registered under (short_name, dst).

    Raises:
        LookupError: if no dataset is registered under that key. The message lists every
            registered key. When short_name ends with "_undefined", it also hints that the
            autoencoder name is probably missing from the URL.
    """
    metadata = NEURON_DATASET_METADATA_REGISTRY.get((short_name, dst))
    if metadata is not None:
        return metadata

    # Distinct loop names so the function's own short_name/dst are not shadowed.
    available = ", ".join(
        _format_dataset_key(registered_name, registered_dst)
        for registered_name, registered_dst in NEURON_DATASET_METADATA_REGISTRY
    )
    error_message = (
        "Could not find collated activation dataset for "
        f"{_format_dataset_key(short_name, dst)}. Available datasets are: {available}"
    )
    if short_name.endswith("_undefined"):
        # This is likely due to the URL not providing the correct autoencoder name.
        if dst in (
            DerivedScalarType.FLATTENED_ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS,
            DerivedScalarType.ATTENTION_AUTOENCODER_LATENT,
        ):
            autoencoder_type = "attention "
        elif dst == DerivedScalarType.MLP_AUTOENCODER_LATENT:
            autoencoder_type = "mlp "
        else:
            autoencoder_type = ""
        error_message += (
            f"\nMay need to specify the {autoencoder_type}autoencoder name in the URL."
        )
    # LookupError subclasses Exception, so existing `except Exception` handlers still apply.
    raise LookupError(error_message)