neuron_explainer/activation_server/neuron_datasets.py (101 lines of code) (raw):

"""
Library for looking up neuron datasets and their associated metadata by name.
"""

from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable


# Metadata describing one collated neuron dataset: which model/variant it belongs to, which
# derived scalar its records contain, a display name, and where the collated data lives.
@immutable
class NeuronDatasetMetadata(CamelCaseBaseModel):
    short_name: str
    """Short name for the dataset, like "gpt2-small"."""
    derived_scalar_type: str
    """The type of scalar that the neuron records in this dataset contain, e.g.
    DerivedScalarType.MLP_POST_ACT."""
    user_visible_name: str
    """Name for humans to read, like "GPT-2 small"."""
    neuron_dataset_path: str
    """Path to the neuron dataset generated by collate_activations."""

    # Take care when adding new fields to this class. If they aren't optional, existing metadata
    # files will cause errors when they're read. You can make them required after ensuring that all
    # metadata files have been updated.


# Maps (short_name, derived scalar type) to the metadata for that dataset. One short name can
# own several datasets that differ only by scalar type (e.g. "gpt2-small" has both an
# "mlp_post_act" and an "attn_write_norm" dataset below). The explicit annotation is needed
# because type checkers cannot infer the key/value types of an empty dict literal.
NEURON_DATASET_METADATA_REGISTRY: dict[tuple[str, DerivedScalarType], NeuronDatasetMetadata] = {}


def register_neuron_dataset_metadata(
    short_name: str, derived_scalar_type: str, user_visible_name: str, neuron_dataset_path: str
) -> None:
    """Add a dataset's metadata to NEURON_DATASET_METADATA_REGISTRY.

    Args:
        short_name: Short dataset name, e.g. "gpt2-small".
        derived_scalar_type: String form of the scalar type. It is converted with
            DerivedScalarType(...) to build the registry key (presumably raising for values
            that are not valid DerivedScalarType members), and is stored unconverted on the
            metadata object.
        user_visible_name: Human-readable display name.
        neuron_dataset_path: Location of the collated activations.

    Registering the same (short_name, derived_scalar_type) pair twice silently replaces the
    earlier entry.
    """
    metadata = NeuronDatasetMetadata(
        short_name=short_name,
        derived_scalar_type=derived_scalar_type,
        user_visible_name=user_visible_name,
        neuron_dataset_path=neuron_dataset_path,
    )
    key = (short_name, DerivedScalarType(derived_scalar_type))
    NEURON_DATASET_METADATA_REGISTRY[key] = metadata


register_neuron_dataset_metadata(
    short_name="gpt2-xl",
    derived_scalar_type="mlp_post_act",
    user_visible_name="GPT-2 XL - MLP neurons",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/data/collated-activations/",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small",
    derived_scalar_type="mlp_post_act",
    user_visible_name="GPT-2 small - MLP neurons",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small_data/collated-activations/",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small",
    derived_scalar_type="attn_write_norm",
    user_visible_name="GPT-2 small - Attention write by token pair",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/attn_write_norm/collated-activations-by-token-pair",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small_ae-mlp-post-act-v1",
    derived_scalar_type="mlp_autoencoder_latent",
    user_visible_name="GPT-2 small - MLP autoencoder (post-act) v1",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/mlp_post_act_v1/collated-activations",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small_ae-resid-delta-mlp-v1",
    derived_scalar_type="mlp_autoencoder_latent",
    user_visible_name="GPT-2 small - MLP autoencoder (write) v1",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_mlp_v1/collated-activations",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small_ae-mlp-post-act-v4",
    derived_scalar_type="mlp_autoencoder_latent",
    user_visible_name="GPT-2 small - MLP autoencoder (post-act) v4",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/mlp_post_act_v4/collated-activations",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small_ae-resid-delta-mlp-v4",
    derived_scalar_type="mlp_autoencoder_latent",
    user_visible_name="GPT-2 small - MLP autoencoder (write) v4",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_mlp_v4/collated-activations",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small_ae-resid-delta-attn-v4",
    derived_scalar_type="attention_autoencoder_latent",
    user_visible_name="GPT-2 small - Attention autoencoder (write) v4",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_attn_v4/collated-activations",
)

register_neuron_dataset_metadata(
    short_name="gpt2-small_ae-resid-delta-attn-v4",
    derived_scalar_type="flattened_attn_write_to_latent_summed_over_heads",
    user_visible_name="GPT-2 small - Attention autoencoder (write) v4 by token pair",
    neuron_dataset_path="https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small/autoencoder_latent/resid_delta_attn_v4/collated-activations-by-token-pair",
)


def get_all_neuron_dataset_metadata() -> list[NeuronDatasetMetadata]:
    """Return the metadata of every registered dataset, in registry (insertion) order."""
    return list(NEURON_DATASET_METADATA_REGISTRY.values())


def get_neuron_dataset_metadata_by_short_name_and_dst(
    short_name: str, dst: DerivedScalarType
) -> NeuronDatasetMetadata:
    """Look up the dataset registered under (short_name, dst).

    Args:
        short_name: Short dataset name, e.g. "gpt2-small".
        dst: Derived scalar type the dataset's records contain.

    Returns:
        The matching NeuronDatasetMetadata.

    Raises:
        Exception: If no dataset is registered under (short_name, dst). The message lists every
            registered (short_name, dst) pair and, when short_name ends with "_undefined",
            suggests specifying the autoencoder name in the URL.
    """
    name_and_type = (short_name, dst)
    metadata = NEURON_DATASET_METADATA_REGISTRY.get(name_and_type)
    if metadata is not None:
        return metadata

    error_message = (
        "Could not find collated activation dataset for "
        f"{name_and_type}. Available datasets are: "
    )
    # Distinct loop names keep the listing from looking like it rebinds the short_name/dst
    # parameters, which are still used below.
    error_message += ", ".join(
        f'("{registered_name}", "{registered_dst}")'
        for (registered_name, registered_dst) in NEURON_DATASET_METADATA_REGISTRY
    )
    if short_name.endswith("_undefined"):
        # This is likely due to the URL not providing the correct autoencoder name
        if dst in [
            DerivedScalarType.FLATTENED_ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS,
            DerivedScalarType.ATTENTION_AUTOENCODER_LATENT,
        ]:
            autoencoder_type = "attention "
        elif dst == DerivedScalarType.MLP_AUTOENCODER_LATENT:
            autoencoder_type = "mlp "
        else:
            autoencoder_type = ""
        error_message += (
            f"\nMay need to specify the {autoencoder_type}autoencoder name in the URL."
        )
    raise Exception(error_message)