neuron_explainer/activations/derived_scalars/indexing.py
"""
This file contains classes for referring to individual nodes (e.g. attention heads), activations
(e.g. attention post-softmax), or derived scalars (e.g. attention head write norm) from a forward
pass. DerivedScalarIndex can be used to index into a DerivedScalarStore.
These classes have a parallel structure to each other. One node index can be associated with
multiple activation indices and derived scalar indices. Derived scalar indices can be associated
with more types of scalars that aren't instantiated as 'activations' in the forward pass as
implemented.
Mirrored versions of these classes are used to refer to the same objects, but in a way that can be
transmitted via pydantic response and request data types for communication with a server. Changes
applied to mirrored dataclasses must be applied also to their unmirrored versions, and vice versa.
"""
import dataclasses
from dataclasses import dataclass
from enum import Enum, unique
from typing import Any, Literal, Union
from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
Dimension,
LayerIndex,
NodeType,
PassType,
)
from neuron_explainer.pydantic import CamelCaseBaseModel, HashableBaseModel, immutable
# Sets the default behavior for whether to detach the layer norm scale everywhere.
# TODO: if all goes well, hard-code this to True and remove the plumbing.
DETACH_LAYER_NORM_SCALE = True
@dataclass(frozen=True)
class DerivedScalarIndex:
"""
Indexes into a DerivedScalarStore and returns a tensor of activations specified by indices.
"""
dst: DerivedScalarType
    # The indices of the activation tensor (not including layer_index). Elements correspond to
    # the elements of scalar_deriver.shape_of_activation_per_token_spec; e.g. MLP activations
    # might have shape (n_tokens, n_neurons). A None element means slice(None) for that dimension.
    tensor_indices: tuple[int | None, ...]
    # The layer index of the activation, if applicable.
    layer_index: LayerIndex
pass_type: PassType
@property
def tensor_index_by_dim(self) -> dict[Dimension, int | None]:
tensor_indices_list = list(self.tensor_indices)
assert len(tensor_indices_list) <= len(self.dst.shape_spec_per_token_sequence), (
f"Too many tensor indices {tensor_indices_list} for "
f"{self.dst.shape_spec_per_token_sequence=}"
)
tensor_indices_list.extend(
[None] * (len(self.dst.shape_spec_per_token_sequence) - len(self.tensor_indices))
)
return dict(zip(self.dst.shape_spec_per_token_sequence, tensor_indices_list))
@classmethod
def from_node_index(
cls,
node_index: "NodeIndex | MirroredNodeIndex",
dst: DerivedScalarType,
) -> "DerivedScalarIndex":
        # Given the desired dst (subject to the constraint that it shares the same node_type),
        # we can convert a NodeIndex to a DerivedScalarIndex.
        assert (
            node_index.node_type == dst.node_type
        ), f"Node type does not match the derived scalar type: {node_index.node_type=}, {dst=}"
return cls(
dst=dst,
layer_index=node_index.layer_index,
tensor_indices=node_index.tensor_indices,
pass_type=node_index.pass_type,
)
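# Illustrative usage sketch for DerivedScalarIndex (non-normative; the DerivedScalarType and
# PassType member names used here are assumptions and may not exist under these names):
#
#   ds_index = DerivedScalarIndex(
#       dst=DerivedScalarType.MLP_POST_ACT,   # assumed member name
#       tensor_indices=(3,),                  # token index 3; trailing dims are implicitly None
#       layer_index=5,
#       pass_type=PassType.FORWARD,           # assumed member name
#   )
#   # tensor_index_by_dim pads missing trailing indices with None and keys them by dimension,
#   # e.g. {<sequence-tokens dim>: 3, <neurons dim>: None}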
@immutable
class MirroredDerivedScalarIndex(HashableBaseModel):
dst: DerivedScalarType
tensor_indices: tuple[int | None, ...]
layer_index: LayerIndex
pass_type: PassType
@classmethod
def from_ds_index(cls, ds_index: DerivedScalarIndex) -> "MirroredDerivedScalarIndex":
return cls(
dst=ds_index.dst,
layer_index=ds_index.layer_index,
tensor_indices=ds_index.tensor_indices,
pass_type=ds_index.pass_type,
)
def to_ds_index(self) -> DerivedScalarIndex:
return DerivedScalarIndex(
dst=self.dst,
layer_index=self.layer_index,
tensor_indices=self.tensor_indices,
pass_type=self.pass_type,
)
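# A minimal sketch of the mirror round-trip (assuming ds_index is a DerivedScalarIndex as in the
# sketch above); the mirrored pydantic model carries the same fields, so the conversion is lossless:
#
#   mirrored = MirroredDerivedScalarIndex.from_ds_index(ds_index)
#   assert mirrored.to_ds_index() == ds_index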
AllOrOneIndex = Union[int, Literal["All"]]
AllOrOneIndices = tuple[AllOrOneIndex, ...]
@dataclass(frozen=True)
class ActivationIndex:
"""
This is parallel to DerivedScalarIndex, but specifically for ActivationLocationType's, not for more general DerivedScalarType's.
"""
activation_location_type: ActivationLocationType
tensor_indices: AllOrOneIndices
layer_index: LayerIndex
pass_type: PassType
@property
def tensor_index_by_dim(self) -> dict[Dimension, AllOrOneIndex]:
        # Copied from DerivedScalarIndex; TODO: have ActivationIndex and DerivedScalarIndex
        # inherit from a shared base class, and perhaps likewise DerivedScalarType and
        # ActivationLocationType?
tensor_indices_list = list(self.tensor_indices)
assert len(tensor_indices_list) <= len(
self.activation_location_type.shape_spec_per_token_sequence
), (
f"Too many tensor indices {tensor_indices_list} for "
f"{self.activation_location_type.shape_spec_per_token_sequence=}"
)
tensor_indices_list.extend(
["All"]
* (
len(self.activation_location_type.shape_spec_per_token_sequence)
- len(self.tensor_indices)
)
)
assert len(tensor_indices_list) == len(
self.activation_location_type.shape_spec_per_token_sequence
)
return dict(
zip(
self.activation_location_type.shape_spec_per_token_sequence,
tensor_indices_list,
)
)
@classmethod
def from_node_index(
cls,
node_index: "NodeIndex | MirroredNodeIndex",
activation_location_type: ActivationLocationType,
) -> "ActivationIndex":
        # Given the desired activation_location_type (subject to the constraint that it shares
        # the same node_type), we can convert a NodeIndex to an ActivationIndex.
        assert (
            node_index.node_type == activation_location_type.node_type
        ), f"Node type does not match the activation location type: {node_index.node_type=}, {activation_location_type=}"
return cls(
activation_location_type=activation_location_type,
layer_index=node_index.layer_index,
tensor_indices=make_all_or_one_from_tensor_indices(node_index.tensor_indices),
pass_type=node_index.pass_type,
)
@property
def ndim(self) -> int:
return compute_indexed_tensor_ndim(
activation_location_type=self.activation_location_type,
tensor_indices=self.tensor_indices,
)
def with_updates(self, **kwargs: Any) -> "ActivationIndex":
"""Given new values for fields of this ActivationIndex, return a new ActivationIndex instance with those
fields updated"""
return dataclasses.replace(self, **kwargs)
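# Illustrative sketch of ActivationIndex indexing semantics (non-normative; the PassType member
# name is an assumption):
#
#   act_index = ActivationIndex(
#       activation_location_type=ActivationLocationType.MLP_POST_ACT,
#       tensor_indices=(3, "All"),            # token 3, all neurons
#       layer_index=5,
#       pass_type=PassType.FORWARD,           # assumed member name
#   )
#   # Assuming MLP post-activations have per-sequence shape (n_tokens, n_neurons), then
#   # act_index.ndim == 2 - 1 == 1 (int indices remove a dimension; "All" keeps it).
#   # act_index.with_updates(layer_index=6) returns a copy with only layer_index changed.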
def make_all_or_one_from_tensor_indices(tensor_indices: tuple[int | None, ...]) -> AllOrOneIndices:
return tuple("All" if tensor_index is None else tensor_index for tensor_index in tensor_indices)
def make_tensor_indices_from_all_or_one_indices(
all_or_one_indices: AllOrOneIndices,
) -> tuple[int | None, ...]:
return tuple(
None if all_or_one_index == "All" else all_or_one_index
for all_or_one_index in all_or_one_indices
)
def compute_indexed_tensor_ndim(
activation_location_type: ActivationLocationType,
tensor_indices: AllOrOneIndices | tuple[int | None, ...],
) -> int:
"""Returns the dimensionality of a tensor of the given ActivationLocationType after being indexed by tensor_indices.
int dimensions are removed from the resulting tensor."""
ndim = activation_location_type.ndim_per_token_sequence - len(
[tensor_index for tensor_index in tensor_indices if tensor_index not in {"All", None}]
)
assert ndim >= 0
return ndim
def make_python_slice_from_tensor_indices(
tensor_indices: tuple[int | None, ...]
) -> tuple[slice | int, ...]:
return make_python_slice_from_all_or_one_indices(
make_all_or_one_from_tensor_indices(tensor_indices)
)
def make_python_slice_from_all_or_one_indices(
all_or_one_indices: AllOrOneIndices,
) -> tuple[slice | int, ...]:
return tuple(
slice(None) if all_or_one_index == "All" else all_or_one_index
for all_or_one_index in all_or_one_indices
)
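# The helpers above convert between the three equivalent index representations; for example
# (a small self-contained sketch, following directly from the definitions above):
#
#   make_all_or_one_from_tensor_indices((0, None))            # -> (0, "All")
#   make_tensor_indices_from_all_or_one_indices((0, "All"))   # -> (0, None)
#   make_python_slice_from_tensor_indices((0, None))          # -> (0, slice(None))
#   make_python_slice_from_all_or_one_indices((0, "All"))     # -> (0, slice(None))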
@immutable
class MirroredActivationIndex(HashableBaseModel):
activation_location_type: ActivationLocationType
tensor_indices: AllOrOneIndices
layer_index: LayerIndex
pass_type: PassType
@classmethod
def from_activation_index(cls, activation_index: ActivationIndex) -> "MirroredActivationIndex":
return cls(
activation_location_type=activation_index.activation_location_type,
layer_index=activation_index.layer_index,
tensor_indices=activation_index.tensor_indices,
pass_type=activation_index.pass_type,
)
def to_activation_index(self) -> ActivationIndex:
return ActivationIndex(
activation_location_type=self.activation_location_type,
layer_index=self.layer_index,
tensor_indices=self.tensor_indices,
pass_type=self.pass_type,
)
@dataclass(frozen=True)
class NodeIndex:
"""
This is parallel to DerivedScalarIndex, but refers to the NodeType associated with a
DerivedScalarType, rather than the DerivedScalarType itself. This is for situations in
which multiple derived scalars are computed for the same node.
"""
node_type: NodeType
tensor_indices: tuple[int | None, ...]
layer_index: LayerIndex
pass_type: PassType
@classmethod
def from_ds_index(
cls,
ds_index: DerivedScalarIndex,
) -> "NodeIndex":
return cls(
node_type=ds_index.dst.node_type,
layer_index=ds_index.layer_index,
tensor_indices=ds_index.tensor_indices,
pass_type=ds_index.pass_type,
)
@classmethod
def from_activation_index(
cls,
activation_index: ActivationIndex,
) -> "NodeIndex":
return cls(
node_type=activation_index.activation_location_type.node_type,
layer_index=activation_index.layer_index,
tensor_indices=make_tensor_indices_from_all_or_one_indices(
activation_index.tensor_indices
),
pass_type=activation_index.pass_type,
)
def with_updates(self, **kwargs: Any) -> "NodeIndex":
"""Given new values for fields of this NodeIndex, return a new NodeIndex instance with those
fields updated"""
return dataclasses.replace(self, **kwargs)
@property
def ndim(self) -> int:
match self.node_type:
case NodeType.ATTENTION_HEAD:
reference_activation_location_type = ActivationLocationType.ATTN_QK_PROBS
case NodeType.MLP_NEURON:
reference_activation_location_type = ActivationLocationType.MLP_POST_ACT
case NodeType.AUTOENCODER_LATENT:
reference_activation_location_type = (
ActivationLocationType.ONLINE_AUTOENCODER_LATENT
)
case NodeType.MLP_AUTOENCODER_LATENT:
reference_activation_location_type = (
ActivationLocationType.ONLINE_MLP_AUTOENCODER_LATENT
)
case NodeType.ATTENTION_AUTOENCODER_LATENT:
reference_activation_location_type = (
ActivationLocationType.ONLINE_ATTENTION_AUTOENCODER_LATENT
)
case NodeType.RESIDUAL_STREAM_CHANNEL:
reference_activation_location_type = ActivationLocationType.RESID_POST_MLP
case _:
raise NotImplementedError(f"Node type {self.node_type} not supported")
return compute_indexed_tensor_ndim(
activation_location_type=reference_activation_location_type,
tensor_indices=self.tensor_indices,
)
def to_subnode_index(self, q_k_or_v: ActivationLocationType) -> "AttnSubNodeIndex":
assert (
self.node_type == NodeType.ATTENTION_HEAD
), f"Node type {self.node_type} is not NodeType.ATTENTION_HEAD"
return AttnSubNodeIndex(
node_type=self.node_type,
layer_index=self.layer_index,
tensor_indices=self.tensor_indices,
pass_type=self.pass_type,
q_k_or_v=q_k_or_v,
)
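# Illustrative sketch of NodeIndex usage for an attention head (non-normative; the PassType member
# name and the tensor index layout are assumptions):
#
#   head_index = NodeIndex(
#       node_type=NodeType.ATTENTION_HEAD,
#       tensor_indices=(7, None, None),       # e.g. head 7, all token positions (assumed layout)
#       layer_index=2,
#       pass_type=PassType.FORWARD,           # assumed member name
#   )
#   # Restrict tracing to the query path of this head:
#   q_subnode = head_index.to_subnode_index(ActivationLocationType.ATTN_QUERY)
#   # Note: MirroredNodeIndex.from_node_index(q_subnode) drops the q_k_or_v field (see below).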
@immutable
class MirroredNodeIndex(HashableBaseModel):
"""This class mirrors the fields of NodeIndex without default values."""
node_type: NodeType
tensor_indices: tuple[int | None, ...]
layer_index: LayerIndex
pass_type: PassType
@classmethod
def from_node_index(cls, node_index: NodeIndex) -> "MirroredNodeIndex":
"""
Note that this conversion may lose information, specifically if the if the NodeIndex
is an instance of a subclass of NodeIndex such as AttnSubNodeIndex.
"""
return cls(
node_type=node_index.node_type,
layer_index=node_index.layer_index,
tensor_indices=node_index.tensor_indices,
pass_type=node_index.pass_type,
)
def to_node_index(self) -> NodeIndex:
return NodeIndex(
node_type=self.node_type,
layer_index=self.layer_index,
tensor_indices=self.tensor_indices,
pass_type=self.pass_type,
)
@dataclass(frozen=True)
class AttnSubNodeIndex(NodeIndex):
"""A NodeIndex that contains an extra piece of metadata, q_k_or_v,
which specifies whether the input to an attention head node should
be restricted to the portion going through the query, key, or value"""
q_k_or_v: ActivationLocationType
def __post_init__(self) -> None:
assert (
self.node_type == NodeType.ATTENTION_HEAD
), f"Node type {self.node_type} is not NodeType.ATTENTION_HEAD"
assert self.q_k_or_v in {
ActivationLocationType.ATTN_QUERY,
ActivationLocationType.ATTN_KEY,
ActivationLocationType.ATTN_VALUE,
}
# TODO: consider subsuming this and the above into NodeIndex/ActivationIndex respectively
@dataclass(frozen=True)
class AttnSubActivationIndex(ActivationIndex):
"""An ActivationIndex that contains an extra piece of metadata, q_or_k,
which specifies whether the input to an attention head node should
be restricted to the portion going through the query or key"""
q_or_k: ActivationLocationType
def __post_init__(self) -> None:
assert self.activation_location_type.node_type == NodeType.ATTENTION_HEAD
assert self.q_or_k in {
ActivationLocationType.ATTN_QUERY,
ActivationLocationType.ATTN_KEY,
}
@immutable
class AblationSpec(CamelCaseBaseModel):
"""A specification for performing ablation on a model."""
index: MirroredActivationIndex
value: float
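# A minimal sketch of an ablation spec that zeroes a single MLP activation (non-normative; the
# PassType member name is an assumption):
#
#   spec = AblationSpec(
#       index=MirroredActivationIndex(
#           activation_location_type=ActivationLocationType.MLP_POST_ACT,
#           tensor_indices=(3, 11),            # token 3, neuron 11
#           layer_index=5,
#           pass_type=PassType.FORWARD,        # assumed member name
#       ),
#       value=0.0,
#   )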
@unique
class AttentionTraceType(Enum):
Q = "Q"
K = "K"
QK = "QK"
"""Q times K"""
V = "V"
"""Allow gradient to flow through value vector; the attention write * gradient with respect to
some downstream node or the loss provides the scalar which is backpropagated"""
@immutable
class NodeAblation(CamelCaseBaseModel):
"""A specification for tracing an upstream node.
This data structure is used by the client. The server converts it to an AblationSpec.
"""
node_index: MirroredNodeIndex
value: float
class PreOrPostAct(str, Enum):
"""Specifies whether to trace from pre- or post-nonlinearity"""
PRE = "pre"
POST = "post"
@dataclass(frozen=True)
class TraceConfig:
"""This specifies a node from which to compute a backward pass, along with whether to trace from
pre- or post-nonlinearity, which subnodes to flow the gradient through in the case of an attention node,
and whether to detach the layer norm scale just before the activation (i.e. whether to flow gradients
through the layer norm scale parameter)."""
node_index: NodeIndex
pre_or_post_act: PreOrPostAct
detach_layer_norm_scale: bool
attention_trace_type: AttentionTraceType | None = None # applies only to attention heads
downstream_trace_config: "TraceConfig | None" = (
None # applies only to attention heads with attention_trace_type == AttentionTraceType.V
)
def __post_init__(self) -> None:
if self.node_index.node_type != NodeType.ATTENTION_HEAD:
assert self.attention_trace_type is None
if self.attention_trace_type != AttentionTraceType.V:
# only tracing through V supports a downstream node
assert self.downstream_trace_config is None
else:
if self.downstream_trace_config is not None:
# repeatedly tracing through V is not allowed; all other types of
# downstream trace configs are fine
assert self.downstream_trace_config.attention_trace_type != AttentionTraceType.V
            # downstream_trace_config is None -> a loss (a function of the logits) is assumed to be defined
@property
def node_type(self) -> NodeType:
return self.node_index.node_type
@property
def tensor_indices(self) -> AllOrOneIndices:
return make_all_or_one_from_tensor_indices(self.node_index.tensor_indices)
@property
def layer_index(self) -> LayerIndex:
return self.node_index.layer_index
@property
def pass_type(self) -> PassType:
return self.node_index.pass_type
@property
def ndim(self) -> int:
return self.node_index.ndim
def with_updated_index(
self,
**kwargs: Any,
) -> "TraceConfig":
return dataclasses.replace(
self,
node_index=self.node_index.with_updates(**kwargs),
)
@classmethod
def from_activation_index(
cls,
activation_index: ActivationIndex,
detach_layer_norm_scale: bool = DETACH_LAYER_NORM_SCALE,
) -> "TraceConfig":
node_index = NodeIndex.from_activation_index(activation_index)
match activation_index.activation_location_type:
case ActivationLocationType.MLP_PRE_ACT | ActivationLocationType.ATTN_QK_LOGITS:
pre_or_post_act = PreOrPostAct.PRE
case (
ActivationLocationType.MLP_POST_ACT
| ActivationLocationType.ATTN_QK_PROBS
| ActivationLocationType.ONLINE_AUTOENCODER_LATENT
):
pre_or_post_act = PreOrPostAct.POST
case _:
raise ValueError(
f"ActivationLocationType {activation_index.activation_location_type} not supported"
)
match node_index.node_type:
case NodeType.ATTENTION_HEAD:
attention_trace_type: AttentionTraceType | None = AttentionTraceType.QK
case _:
attention_trace_type = None
downstream_trace_config = None
return cls(
node_index=node_index,
pre_or_post_act=pre_or_post_act,
detach_layer_norm_scale=detach_layer_norm_scale,
attention_trace_type=attention_trace_type,
downstream_trace_config=downstream_trace_config,
)
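# A minimal sketch of building a TraceConfig from an activation index (non-normative; the PassType
# member name is an assumption). Per from_activation_index above, an MLP post-activation maps to
# PreOrPostAct.POST with no attention_trace_type:
#
#   act_index = ActivationIndex(
#       activation_location_type=ActivationLocationType.MLP_POST_ACT,
#       tensor_indices=(3, 11),
#       layer_index=5,
#       pass_type=PassType.FORWARD,           # assumed member name
#   )
#   trace_config = TraceConfig.from_activation_index(act_index)
#   # trace_config.pre_or_post_act == PreOrPostAct.POST
#   # trace_config.attention_trace_type is None
#   # trace_config.detach_layer_norm_scale == DETACH_LAYER_NORM_SCALE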
@immutable
class MirroredTraceConfig(HashableBaseModel):
node_index: MirroredNodeIndex
pre_or_post_act: PreOrPostAct
detach_layer_norm_scale: bool
attention_trace_type: AttentionTraceType | None = None # applies only to attention heads
downstream_trace_config: "MirroredTraceConfig | None" = (
None # applies only to attention heads with attention_trace_type == AttentionTraceType.V
)
def to_trace_config(self) -> TraceConfig:
downstream_trace_config = (
self.downstream_trace_config.to_trace_config()
if self.downstream_trace_config is not None
else None
)
return TraceConfig(
node_index=self.node_index.to_node_index(),
pre_or_post_act=self.pre_or_post_act,
detach_layer_norm_scale=self.detach_layer_norm_scale,
attention_trace_type=self.attention_trace_type,
downstream_trace_config=downstream_trace_config,
)
@classmethod
def from_trace_config(cls, trace_config: TraceConfig) -> "MirroredTraceConfig":
mirrored_downstream_trace_config = (
cls.from_trace_config(trace_config.downstream_trace_config)
if trace_config.downstream_trace_config is not None
else None
)
return cls(
node_index=MirroredNodeIndex.from_node_index(trace_config.node_index),
pre_or_post_act=trace_config.pre_or_post_act,
detach_layer_norm_scale=trace_config.detach_layer_norm_scale,
attention_trace_type=trace_config.attention_trace_type,
downstream_trace_config=mirrored_downstream_trace_config,
)
@immutable
class NodeToTrace(CamelCaseBaseModel):
"""A specification for tracing a node.
This data structure is used by the client. The server converts it to an activation index and
an ablation spec.
In the case of tracing through attention value, there can be up to two NodeToTrace
objects: one upstream and one downstream. First, a gradient is computed with respect to the
downstream node. Then, the direct effect of the upstream (attention) node on that downstream
node is computed. Then, the gradient is computed with respect to that direct effect, propagated
through V
"""
node_index: MirroredNodeIndex
attention_trace_type: AttentionTraceType | None
downstream_trace_config: MirroredTraceConfig | None
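# A minimal sketch of a client-side NodeToTrace for tracing an attention head through V toward a
# downstream MLP node (non-normative; the PassType member name, tensor index layouts, and field
# values are illustrative assumptions):
#
#   node_to_trace = NodeToTrace(
#       node_index=MirroredNodeIndex(
#           node_type=NodeType.ATTENTION_HEAD,
#           tensor_indices=(7, None, None),    # assumed layout
#           layer_index=2,
#           pass_type=PassType.FORWARD,        # assumed member name
#       ),
#       attention_trace_type=AttentionTraceType.V,
#       downstream_trace_config=MirroredTraceConfig(
#           node_index=MirroredNodeIndex(
#               node_type=NodeType.MLP_NEURON,
#               tensor_indices=(3, 11),
#               layer_index=5,
#               pass_type=PassType.FORWARD,    # assumed member name
#           ),
#           pre_or_post_act=PreOrPostAct.POST,
#           detach_layer_norm_scale=DETACH_LAYER_NORM_SCALE,
#       ),
#   )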