# neuron_explainer/models/autoencoder_context.py
import os
from dataclasses import dataclass, field
from typing import Union

import torch

from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.file_utils import copy_to_local_cache, file_exists
from neuron_explainer.models import Autoencoder
from neuron_explainer.models.model_component_registry import Dimension, LayerIndex, NodeType

@dataclass(frozen=True)
class AutoencoderSpec:
    """Parameters used in the construction of an AutoencoderConfig object. Kept separate
    so that constructing a spec does not trigger AutoencoderConfig's validation."""

    dst: DerivedScalarType
    autoencoder_path_by_layer_index: dict[LayerIndex, str]

@dataclass(frozen=True)
class AutoencoderConfig:
    """
    This class specifies a set of autoencoders to load from disk, for one or more layer indices.
    The derived scalar type (dst) indicates the type of activation that the autoencoder was
    trained on, and that will be fed into the autoencoder.
    """

    dst: DerivedScalarType
    autoencoder_path_by_layer_index: dict[LayerIndex, str]

    def __post_init__(self) -> None:
        assert len(self.autoencoder_path_by_layer_index) > 0
        if len(self.autoencoder_path_by_layer_index) > 1:
            assert (
                None not in self.autoencoder_path_by_layer_index.keys()
            ), "autoencoder_path_by_layer_index keys must be [None] or all-int layer indices"

    @classmethod
    def from_spec(cls, params: AutoencoderSpec) -> "AutoencoderConfig":
        return cls(
            dst=params.dst,
            autoencoder_path_by_layer_index=params.autoencoder_path_by_layer_index,
        )
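
# Illustrative construction (a sketch; the blob-storage URL below is hypothetical):
#
#     spec = AutoencoderSpec(
#         dst=DerivedScalarType.MLP_POST_ACT,
#         autoencoder_path_by_layer_index={0: "https://example.invalid/autoencoders/layer0.pt"},
#     )
#     config = AutoencoderConfig.from_spec(spec)
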

@dataclass(frozen=True)
class AutoencoderContext:
autoencoder_config: AutoencoderConfig
device: torch.device
_cached_autoencoders_by_path: dict[str, Autoencoder] = field(default_factory=dict)
omit_dead_latents: bool = False
"""
Omit dead latents to save memory. Only happens if self.warmup() is called. Because we omit the
same number of latents from all autoencoders, we can only omit the smallest number of dead
latents among all autoencoders.
"""
@property
def num_autoencoder_directions(self) -> int:
"""Note that this property might change after warmup() is called, if omit_dead_latents is True."""
if len(self._cached_autoencoders_by_path) == 0:
raise ValueError(
"num_autoencoder_directions is not populated yet. Call warmup() first."
)
else:
# all autoencoders have the same number of directions, so we can just check one
first_autoencoder = next(iter(self._cached_autoencoders_by_path.values()))
return first_autoencoder.latent_bias.shape[0]
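
    # The smallest dead-latent count across all loaded autoencoders. warmup() drops
    # exactly this many latents from each autoencoder so that all widths stay equal.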
@property
def _min_n_dead_latents(self) -> int:
return min(
count_dead_latents(autoencoder)
for autoencoder in self._cached_autoencoders_by_path.values()
)
@property
def dst(self) -> DerivedScalarType:
return self.autoencoder_config.dst
@property
def layer_indices(self) -> set[LayerIndex]:
return set(self.autoencoder_config.autoencoder_path_by_layer_index.keys())
def get_autoencoder(self, layer_index: LayerIndex) -> Autoencoder:
autoencoder_azure_path = self.autoencoder_config.autoencoder_path_by_layer_index.get(
layer_index
)
if autoencoder_azure_path is None:
raise ValueError(f"No autoencoder path for layer_index {layer_index}")
else:
if autoencoder_azure_path in self._cached_autoencoders_by_path:
autoencoder = self._cached_autoencoders_by_path[autoencoder_azure_path]
else:
# Check if the autoencoder is cached on disk
disk_cache_path = os.path.join(
"/tmp", autoencoder_azure_path.replace("https://", "")
)
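                # e.g. "https://host/container/ae.pt" is cached at "/tmp/host/container/ae.pt"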
if file_exists(disk_cache_path):
print(f"Loading autoencoder from disk cache: {disk_cache_path}")
else:
print(f"Reading autoencoder from blob storage: {autoencoder_azure_path}")
copy_to_local_cache(autoencoder_azure_path, disk_cache_path)
state_dict = torch.load(disk_cache_path, map_location=self.device)
# released autoencoders are saved as a dict for better compatibility
assert isinstance(state_dict, dict)
autoencoder = Autoencoder.from_state_dict(state_dict, strict=False).to(self.device)
self._cached_autoencoders_by_path[autoencoder_azure_path] = autoencoder
# freeze the autoencoder
for p in autoencoder.parameters():
p.requires_grad = False
return autoencoder
def warmup(self) -> None:
"""Load all autoencoders into memory."""
for layer_index in self.layer_indices:
self.get_autoencoder(layer_index)
# num_autoencoder_directions is always populated after warmup
n_latents = self.num_autoencoder_directions
if self.omit_dead_latents:
# drop the dead latents to save memory, but keep the same number of directions for all autoencoders
if self._min_n_dead_latents > 0:
print(f"Omitting {self._min_n_dead_latents} dead latents from all autoencoders")
n_latents_to_keep = n_latents - self._min_n_dead_latents
for key, autoencoder in self._cached_autoencoders_by_path.items():
self._cached_autoencoders_by_path[key] = omit_least_active_latents(
autoencoder, n_latents_to_keep=n_latents_to_keep
)
def get_parameterized_dimension_sizes(self) -> dict[Dimension, int]:
"""A dictionary specifying the size of the parameterized dimensions; for convenient use with ScalarDerivers"""
return {
Dimension.AUTOENCODER_LATENTS: self.num_autoencoder_directions,
}
@property
def autoencoder_node_type(self) -> NodeType | None:
return _autoencoder_node_type_by_input_dst.get(self.dst)
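
# Illustrative usage (a sketch; assumes `config` is an AutoencoderConfig whose artifact
# paths actually exist):
#
#     context = AutoencoderContext(autoencoder_config=config, device=torch.device("cpu"))
#     context.warmup()  # required before reading num_autoencoder_directions
#     autoencoder = context.get_autoencoder(layer_index=0)
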

_autoencoder_node_type_by_input_dst: dict[DerivedScalarType, NodeType] = {
# add more mappings as needed
DerivedScalarType.MLP_POST_ACT: NodeType.MLP_AUTOENCODER_LATENT,
DerivedScalarType.RESID_DELTA_MLP_FROM_MLP_POST_ACT: NodeType.MLP_AUTOENCODER_LATENT,
DerivedScalarType.RESID_DELTA_MLP: NodeType.MLP_AUTOENCODER_LATENT,
DerivedScalarType.RESID_DELTA_ATTN: NodeType.ATTENTION_AUTOENCODER_LATENT,
DerivedScalarType.ATTN_WRITE: NodeType.ATTENTION_AUTOENCODER_LATENT,
}

@dataclass(frozen=True)
class MultiAutoencoderContext:
autoencoder_context_by_node_type: dict[NodeType, AutoencoderContext]
@classmethod
def from_context_or_multi_context(
cls,
input: Union[AutoencoderContext, "MultiAutoencoderContext", None],
) -> Union["MultiAutoencoderContext", None]:
if isinstance(input, AutoencoderContext):
return cls.from_autoencoder_context_list([input])
elif input is None:
return None
else:
return input
@classmethod
def from_autoencoder_context_list(
cls, autoencoder_context_list: list[AutoencoderContext]
) -> "MultiAutoencoderContext":
# check if there are duplicate node types
node_types = [
_autoencoder_node_type_by_input_dst[autoencoder_context.dst]
for autoencoder_context in autoencoder_context_list
]
if len(node_types) != len(set(node_types)):
raise ValueError(f"Cannot load two autoencoders with the same node type ({node_types})")
return cls(
autoencoder_context_by_node_type={
_autoencoder_node_type_by_input_dst[autoencoder_context.dst]: autoencoder_context
for autoencoder_context in autoencoder_context_list
}
)
def get_autoencoder_context(
self, node_type: NodeType | None = None
) -> AutoencoderContext | None:
if node_type is None or node_type == NodeType.AUTOENCODER_LATENT: # handle default case
return self.get_single_autoencoder_context()
else:
return self.autoencoder_context_by_node_type.get(node_type, None)
@property
def has_single_autoencoder_context(self) -> bool:
return len(self.autoencoder_context_by_node_type) == 1
def get_single_autoencoder_context(self) -> AutoencoderContext:
assert self.has_single_autoencoder_context
return next(iter(self.autoencoder_context_by_node_type.values()))
def get_autoencoder(
self, layer_index: LayerIndex, node_type: NodeType | None = None
) -> Autoencoder:
autoencoder_context = self.get_autoencoder_context(node_type)
assert autoencoder_context is not None
return autoencoder_context.get_autoencoder(layer_index)
def warmup(self) -> None:
"""Load all autoencoders into memory."""
for node_type, autoencoder_context in self.autoencoder_context_by_node_type.items():
print(f"Warming up autoencoder {node_type}")
autoencoder_context.warmup()
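
# Illustrative usage (a sketch; assumes `mlp_context` and `attn_context` are
# AutoencoderContexts whose dsts map to distinct node types):
#
#     multi = MultiAutoencoderContext.from_autoencoder_context_list([mlp_context, attn_context])
#     multi.warmup()
#     mlp_ae = multi.get_autoencoder(layer_index=0, node_type=NodeType.MLP_AUTOENCODER_LATENT)
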
def get_decoder_weight(autoencoder: Autoencoder) -> torch.Tensor:
return autoencoder.decoder.weight.T # shape (n_latents, d_ff)

def get_autoencoder_output_weight_by_layer_index(
autoencoder_context: AutoencoderContext,
) -> dict[LayerIndex, torch.Tensor]:
return {
layer_index: get_decoder_weight(
autoencoder_context.get_autoencoder(layer_index)
) # shape (n_latents, d_ff)
for layer_index in autoencoder_context.layer_indices
}
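
# Illustrative usage (a sketch; assumes `context` has been warmed up):
#
#     decoder_weight_by_layer = get_autoencoder_output_weight_by_layer_index(context)
#     # decoder_weight_by_layer[layer_index] has shape (n_latents, d_ff)
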
ACTIVATION_FREQUENCY_THRESHOLD_FOR_DEAD_LATENTS = 1e-8

def count_dead_latents(autoencoder: Autoencoder) -> int:
if hasattr(autoencoder, "latents_activation_frequency"):
if torch.all(autoencoder.latents_activation_frequency == 0):
raise ValueError("latents_activation_frequency is all zeros, all latents are dead.")
dead_latents_mask = (
autoencoder.latents_activation_frequency
< ACTIVATION_FREQUENCY_THRESHOLD_FOR_DEAD_LATENTS
)
num_dead_latents = int(dead_latents_mask.sum().item())
return num_dead_latents
else:
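        # without recorded activation frequencies, assume no latents are dead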
return 0

def omit_least_active_latents(
autoencoder: Autoencoder,
n_latents_to_keep: int,
    # If preserve_indices is True, ignore the stored activation frequencies and keep the
    # first n_latents_to_keep latents; this preserves latent indices relative to the
    # original autoencoder.
    preserve_indices: bool = True,
) -> Autoencoder:
n_latents_original = int(autoencoder.latent_bias.shape[0])
if n_latents_to_keep >= n_latents_original:
return autoencoder
device: torch.device = autoencoder.encoder.weight.device
    # create the keep mask (True for latents to keep, False for latents to drop)
    mask = torch.ones(n_latents_original, dtype=torch.bool, device=device)
    if preserve_indices or not hasattr(autoencoder, "latents_activation_frequency"):
        # drop the trailing latents, so surviving latents keep their original indices
        mask[n_latents_to_keep:] = False
    else:
        # drop the least active latents
        order = torch.argsort(autoencoder.latents_activation_frequency, descending=True)
        mask[order[n_latents_to_keep:]] = False
# apply the mask to a new autoencoder
n_latents = int(mask.sum().item())
d_model = autoencoder.pre_bias.shape[0]
new_autoencoder = Autoencoder(n_latents, d_model).to(device)
new_autoencoder.encoder.weight.data = autoencoder.encoder.weight[mask, :].clone()
new_autoencoder.decoder.weight.data = autoencoder.decoder.weight[:, mask].clone()
new_autoencoder.latent_bias.data = autoencoder.latent_bias[mask].clone()
new_autoencoder.stats_last_nonzero.data = autoencoder.stats_last_nonzero[mask].clone()
if hasattr(autoencoder, "latents_mean_square"):
new_autoencoder.register_buffer(
"latents_mean_square", torch.zeros(n_latents, dtype=torch.float)
)
new_autoencoder.latents_mean_square.data = autoencoder.latents_mean_square[mask].clone()
if hasattr(autoencoder, "latents_activation_frequency"):
new_autoencoder.register_buffer(
"latents_activation_frequency", torch.ones(n_latents, dtype=torch.float)
)
new_autoencoder.latents_activation_frequency.data = (
autoencoder.latents_activation_frequency[mask].clone()
)
return new_autoencoder
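
# Illustrative check of the pruning invariants (a sketch; assumes `autoencoder` is a
# loaded Autoencoder with `n` latents, n > 10):
#
#     pruned = omit_least_active_latents(autoencoder, n_latents_to_keep=n - 10)
#     assert pruned.latent_bias.shape[0] == n - 10
#     assert pruned.encoder.weight.shape[0] == n - 10  # encoder rows correspond to latents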