neuron_explainer/models/model_registry.py (101 lines of code) (raw):

from dataclasses import dataclass

import torch

from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.models import Transformer
from neuron_explainer.models.autoencoder_context import (
    AutoencoderConfig,
    AutoencoderContext,
    AutoencoderSpec,
)

# Public blob-storage root hosting both the subject-model checkpoints and the sparse
# autoencoder checkpoints registered below.
_BLOB_STORAGE_ROOT = "https://openaipublic.blob.core.windows.net"

# Each GPT-2 small autoencoder release below provides one checkpoint per layer index 0..11.
_GPT2_SMALL_N_LAYERS = 12


@dataclass(frozen=True)
class StandardModelSpec:
    """Immutable description of where a registered subject model's checkpoint lives."""

    model_path: str  # checkpoint path


def _gpt2_model_spec(size: str) -> StandardModelSpec:
    """Return the spec for the GPT-2 checkpoint of the given size ("small", "medium", ...)."""
    return StandardModelSpec(
        model_path=f"{_BLOB_STORAGE_ROOT}/neuron-explainer/subject-models/gpt2/{size}"
    )


def _gpt2_small_autoencoder_spec(dst: DerivedScalarType, dirname: str) -> AutoencoderSpec:
    """Return an AutoencoderSpec holding one checkpoint URL per GPT-2 small layer.

    Args:
        dst: Passed through unchanged as the spec's ``dst``.
        dirname: Release directory under ``sparse-autoencoder/gpt2-small/`` in blob storage;
            layer ``i`` is read from ``<dirname>/autoencoders/<i>.pt``.
    """
    return AutoencoderSpec(
        dst=dst,
        autoencoder_path_by_layer_index={
            layer_index: (
                f"{_BLOB_STORAGE_ROOT}/sparse-autoencoder/gpt2-small/{dirname}"
                f"/autoencoders/{layer_index}.pt"
            )
            for layer_index in range(_GPT2_SMALL_N_LAYERS)
        },
    )


_MODEL_SPECS: dict[str, StandardModelSpec] = {
    # GPT-2 series
    "gpt2-small": _gpt2_model_spec("small"),
    "gpt2-medium": _gpt2_model_spec("medium"),
    "gpt2-large": _gpt2_model_spec("large"),
    "gpt2-xl": _gpt2_model_spec("xl"),
}

# Autoencoders are only registered for a subset of _MODEL_SPECS (currently gpt2-small).
_AUTOENCODER_SPECS: dict[str, dict[str, AutoencoderSpec]] = {
    "gpt2-small": {
        # released December 2023
        "ae-mlp-post-act-v1": _gpt2_small_autoencoder_spec(
            DerivedScalarType.MLP_POST_ACT, "mlp_post_act"
        ),
        "ae-resid-delta-mlp-v1": _gpt2_small_autoencoder_spec(
            DerivedScalarType.RESID_DELTA_MLP, "resid_delta_mlp"
        ),
        # released March 2024
        "ae-mlp-post-act-v4": _gpt2_small_autoencoder_spec(
            DerivedScalarType.MLP_POST_ACT, "mlp_post_act_v4"
        ),
        "ae-resid-delta-mlp-v4": _gpt2_small_autoencoder_spec(
            DerivedScalarType.RESID_DELTA_MLP, "resid_delta_mlp_v4"
        ),
        "ae-resid-delta-attn-v4": _gpt2_small_autoencoder_spec(
            DerivedScalarType.RESID_DELTA_ATTN, "resid_delta_attn_v4"
        ),
    },
}


def list_autoencoder_names(model_name: str = "gpt2-small") -> list[str]:
    """Return the names of the autoencoders registered for ``model_name``.

    Raises:
        KeyError: If no autoencoders are registered for ``model_name``.
    """
    return list(_AUTOENCODER_SPECS[model_name])


def get_standard_model_spec(model_name: str) -> StandardModelSpec:
    """Return the registered spec for ``model_name``.

    Raises:
        KeyError: If ``model_name`` is not registered; the message lists the known names.
    """
    try:
        return _MODEL_SPECS[model_name]
    except KeyError:
        # Same exception type as before, but with an actionable message; the original
        # lookup failure adds nothing, so it is not chained.
        raise KeyError(
            f"Unknown standard model {model_name!r}. Known models: {list(_MODEL_SPECS)}"
        ) from None


def load_standard_transformer(model_name: str, device: torch.device | None = None) -> Transformer:
    """Load the registered model ``model_name`` (see ``get_standard_model_spec``).

    Raises:
        KeyError: If ``model_name`` is not registered.
    """
    print(f"Loading standard model {model_name}...")
    model_spec = get_standard_model_spec(model_name)
    return load_standard_transformer_from_model_spec(model_spec, device=device)


def load_standard_transformer_from_model_spec(
    model_spec: StandardModelSpec, device: torch.device | None = None
) -> Transformer:
    """Load a Transformer from ``model_spec.model_path`` with ``dtype=torch.float32``.

    ``device`` is forwarded unchanged to ``Transformer.load``; how ``None`` is handled is up
    to that method.
    """
    return Transformer.load(
        model_spec.model_path,
        dtype=torch.float32,
        device=device,
    )


def make_autoencoder_context(
    model_name: str,
    autoencoder_name: str,
    device: torch.device,
    omit_dead_latents: bool = False,
) -> AutoencoderContext:
    """Build an AutoencoderContext for a registered autoencoder.

    Args:
        model_name: Key into the autoencoder registry (e.g. "gpt2-small").
        autoencoder_name: Autoencoder name registered for that model
            (see ``list_autoencoder_names``).
        device: Forwarded to ``AutoencoderContext``.
        omit_dead_latents: Forwarded to ``AutoencoderContext``.

    Raises:
        ValueError: If ``autoencoder_name`` is not registered for ``model_name``, including
            when ``model_name`` has no registered autoencoders at all.
    """
    # Look up the model level with .get as well, so that a model without registered
    # autoencoders (e.g. "gpt2-medium") or an unknown model name yields the ValueError below
    # instead of a bare KeyError escaping while the error message is being built.
    autoencoder_specs_for_model = _AUTOENCODER_SPECS.get(model_name, {})
    autoencoder_spec = autoencoder_specs_for_model.get(autoencoder_name)
    if autoencoder_spec is None:
        raise ValueError(
            f"No autoencoder spec found for model {model_name} and autoencoder {autoencoder_name}. "
            f"Available autoencoders for model {model_name} are: {list(autoencoder_specs_for_model.keys())}"
        )
    autoencoder_config = AutoencoderConfig.from_spec(autoencoder_spec)
    return AutoencoderContext(
        autoencoder_config=autoencoder_config,
        device=device,
        omit_dead_latents=omit_dead_latents,
    )