# neuron_explainer/activations/derived_scalars/postprocessing.py
from abc import ABC, abstractmethod
from typing import Any
import torch
from neuron_explainer.activations.derived_scalars.derived_scalar_store import DerivedScalarStore
from neuron_explainer.activations.derived_scalars.indexing import (
DETACH_LAYER_NORM_SCALE,
AttnSubNodeIndex,
DerivedScalarIndex,
MirroredNodeIndex,
NodeIndex,
PreOrPostAct,
TraceConfig,
)
from neuron_explainer.activations.derived_scalars.locations import (
get_previous_residual_dst_for_node_type,
)
from neuron_explainer.activations.derived_scalars.reconstituter_class import (
make_reconstituted_gradient_fn,
)
from neuron_explainer.activations.derived_scalars.scalar_deriver import DerivedScalarType, DstConfig
from neuron_explainer.models.autoencoder_context import (
AutoencoderContext,
MultiAutoencoderContext,
get_decoder_weight,
)
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
Dimension,
NodeType,
PassType,
WeightLocationType,
)
from neuron_explainer.models.model_context import (
ModelContext,
StandardModelContext,
get_embedding,
get_unembedding_with_ln_gain,
)
class DerivedScalarPostprocessor(ABC):
"""
    A parent class for objects that perform postprocessing on specific tensors of derived scalars.
    This postprocessing is generally assumed to require model weights, hence ModelContext; optionally,
    it may also require autoencoder weights, hence AutoencoderContext.
    The important logic is in postprocess(), which takes a node index and a DerivedScalarStore, and in
    postprocess_tensor(), which takes a ds_index and the value produced by ds_store[ds_index] for some
    presumed ds_store, and returns the postprocessed value. Both the value and the metadata in ds_index
    may be required for the computation (e.g. the indices might specify which part of a weight tensor
    is needed).
"""
_input_dst_by_node_type: dict[NodeType, DerivedScalarType]
# TODO: this should really match derived scalar types based on the compatibility of their indexing prefixes, rather
# than based on their node_types. Possibly this could take the form of: _input_dst_by_dimensions,
# and check whether a given derived scalar type's dimensions are a prefix of the input derived scalar type's dimensions.
# this could avoid the need for the _maybe_convert_input_node_type() method.
def _extract_tensor_for_postprocessing(
self,
node_index: NodeIndex | MirroredNodeIndex,
ds_store: DerivedScalarStore,
) -> tuple[DerivedScalarIndex, torch.Tensor, dict[str, Any]]:
"""
        Converts node_index to the corresponding ds_index (using self.convert_node_index_to_ds_index()),
        and returns that ds_index, the corresponding derived_scalars tensor from ds_store, and any
        additional kwargs required by self.postprocess_tensor().
This allows callers to access derived scalars from ds_store without having to check the derived
scalar types in the store.
"""
if isinstance(node_index, MirroredNodeIndex):
node_index = MirroredNodeIndex.to_node_index(node_index)
assert isinstance(node_index, NodeIndex), f"{node_index=}"
assert node_index.pass_type in ds_store.pass_types, (
f"Pass type {node_index.pass_type} not supported by this DerivedScalarStore; "
f"supported pass types are {ds_store.pass_types}"
)
ds_index = self.convert_node_index_to_ds_index(node_index)
kwargs = self.get_postprocess_tensor_kwargs(node_index, ds_store)
return ds_index, ds_store[ds_index], kwargs
@abstractmethod
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
"""For a specified node index, return the corresponding ds_index to submit as arguments to postprocess_tensor()"""
pass
def get_postprocess_tensor_kwargs(
self, node_index: NodeIndex, ds_store: DerivedScalarStore
) -> dict[str, Any]:
"""
Returns a dictionary of keyword arguments that should be passed to postprocess_tensor() for the given node index.
Varies based on child class, and can be empty.
"""
return {}
def postprocess(
self,
node_index: NodeIndex | MirroredNodeIndex,
ds_store: DerivedScalarStore,
) -> torch.Tensor:
"""
        The primary entry point of each child class; takes a node index and a DerivedScalarStore assumed to
        contain the DerivedScalarTypeAndPassType compatible with that node_type, and returns a postprocessed
        value. The postprocessing steps in general depend on the fields of the corresponding ds_index, as
        well as on any additional kwargs defined in self.get_postprocess_tensor_kwargs().
"""
ds_index, derived_scalars, kwargs = self._extract_tensor_for_postprocessing(
node_index, ds_store
)
return self.postprocess_tensor(ds_index, derived_scalars, **kwargs)
def postprocess_multiple_nodes(
self,
node_indices: list[NodeIndex],
ds_store: DerivedScalarStore,
) -> list[torch.Tensor]:
"""
A default implementation for postprocessing multiple nodes at once, which calls postprocess() for each node.
This can be overridden for performance reasons if a more efficient implementation is possible for a given
DerivedScalarPostprocessor.
"""
return [self.postprocess(node_index, ds_store) for node_index in node_indices]
@abstractmethod
def postprocess_tensor(
self,
ds_index: DerivedScalarIndex,
derived_scalars: torch.Tensor,
**kwargs: Any,
) -> torch.Tensor:
"""
An alternative entry point for postprocessing, which takes a ds_index and a derived_scalars tensor;
use this if you do not have access to a full DerivedScalarStore.
"""
...
def get_input_dst_and_config_list(
self,
requested_dst_and_config_list: list[tuple[DerivedScalarType, DstConfig]],
) -> list[tuple[DerivedScalarType, DstConfig]]:
"""
This matches the nodes reflected in the requested derived scalar types to the nodes supported by the postprocessor,
and returns a list of derived scalar types and configurations that should be collected into a DerivedScalarStore to
be passed to postprocess().
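        For example (illustrative): if the caller requests a DST whose node_type is NodeType.MLP_NEURON and
        this postprocessor maps NodeType.MLP_NEURON -> DerivedScalarType.MLP_POST_ACT, then MLP_POST_ACT is
        returned paired with that DST's config, together with whatever get_constitutive_dst_and_config_list()
        adds. Requested DSTs whose node types have no entry in _input_dst_by_node_type are dropped.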
"""
requested_dsts = [dst for dst, _ in requested_dst_and_config_list]
dst_configs = [dst_config for _, dst_config in requested_dst_and_config_list]
requested_node_types = [dst.node_type for dst in requested_dsts]
assert len(requested_node_types) == len(
set(requested_node_types)
), "Requested derived scalar types must have unique node types"
input_dsts_and_configs = []
for i, node_type in enumerate(requested_node_types):
dst = self._input_dst_by_node_type.get(node_type)
if dst is not None:
input_dsts_and_configs.append((dst, dst_configs[i]))
return input_dsts_and_configs + self.get_constitutive_dst_and_config_list()
def get_constitutive_dst_and_config_list(self) -> list[tuple[DerivedScalarType, DstConfig]]:
"""
Returns a list of derived scalar types and configurations that should be collected into a DerivedScalarStore to
be passed to postprocess(), no matter what the requested derived scalar types are. Varies based on the child
class, and can be empty.
"""
return []
class ResidualWriteConverter(DerivedScalarPostprocessor):
"""
    Converts activations to a direction in residual stream space, using write tensors. Valid input activations
    are MLP_POST_ACT, ATTN_WEIGHTED_VALUE (equal to post-softmax attention * value), ONLINE_AUTOENCODER_LATENT,
    and RESID_POST_EMBEDDING.
    The input dsts and node types accepted by this converter match, except in the case of attention heads;
    this is because we require more information (i.e. the entire value vector) to compute the write to the
    residual stream. node_type == NodeType.V_CHANNEL is a piece of metadata saying that the last index of a
    derived scalar corresponds to a single dimension in v-space, or equivalently that if you index all but
    the last index of the derived scalar, you get a vector in the v-space basis.
"""
_input_dst_by_node_type: dict[NodeType, DerivedScalarType] = {
NodeType.MLP_NEURON: DerivedScalarType.MLP_POST_ACT,
NodeType.ATTENTION_HEAD: DerivedScalarType.ATTN_WEIGHTED_VALUE,
NodeType.LAYER: DerivedScalarType.RESID_POST_EMBEDDING,
}
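    # Conceptually, the residual-stream write of an MLP neuron is its (scalar) post-activation times the
    # corresponding row of the MLP output projection, and the write of an attention head is its
    # attention-weighted value vector passed through the head's output projection. Illustrative sketch
    # (hypothetical names; the actual weight lookup happens in _get_output_weight() below):
    #   mlp_write = post_act * W_out[neuron_index]                       # (d_model,)
    #   attn_write = torch.einsum("a,ad->d", weighted_value, W_o_head)   # (d_model,)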
def __init__(
self,
model_context: ModelContext,
multi_autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None,
):
self._model_context = model_context
self._multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
multi_autoencoder_context
)
if self._multi_autoencoder_context is not None:
if (
NodeType.MLP_AUTOENCODER_LATENT
in self._multi_autoencoder_context.autoencoder_context_by_node_type
):
self._input_dst_by_node_type[
NodeType.MLP_AUTOENCODER_LATENT
] = DerivedScalarType.ONLINE_MLP_AUTOENCODER_LATENT
if (
NodeType.ATTENTION_AUTOENCODER_LATENT
in self._multi_autoencoder_context.autoencoder_context_by_node_type
):
self._input_dst_by_node_type[
NodeType.ATTENTION_AUTOENCODER_LATENT
] = DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_LATENT
if self._multi_autoencoder_context.has_single_autoencoder_context:
self._input_dst_by_node_type[
NodeType.AUTOENCODER_LATENT
] = DerivedScalarType.ONLINE_AUTOENCODER_LATENT
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
        supported_node_types = list(self._input_dst_by_node_type.keys())
        assert node_index.node_type in supported_node_types, (
            f"Node type {node_index.node_type} not supported by this converter; "
            f"supported node types are {supported_node_types}"
        )
        dst_for_write = self._input_dst_by_node_type[node_index.node_type]
if node_index.node_type == NodeType.LAYER:
# remove the final, singleton dimension, which is not in the converted derived scalar type
assert len(node_index.tensor_indices) == 2
assert node_index.tensor_indices[1] == 0
updated_tensor_indices: tuple[int | None, ...] = node_index.tensor_indices[:-1]
else:
updated_tensor_indices = node_index.tensor_indices
ds_index = DerivedScalarIndex.from_node_index(
node_index.with_updates(
node_type=dst_for_write.node_type, tensor_indices=updated_tensor_indices
),
dst_for_write,
)
return ds_index
def _maybe_decode(
self, ds_index: DerivedScalarIndex, derived_scalars: torch.Tensor
) -> torch.Tensor:
"""decodes if ds_index.dst is an autoencoder latent type, and if so, returns the
derived_scalars decoded by the decoder weight for the corresponding autoencoder. Otherwise, returns the
derived_scalars unchanged."""
if ds_index.dst.is_autoencoder_latent:
assert self._multi_autoencoder_context is not None
assert derived_scalars.ndim == 0
autoencoder = self._multi_autoencoder_context.get_autoencoder(
ds_index.layer_index, node_type=ds_index.dst.node_type
)
assert Dimension.AUTOENCODER_LATENTS in ds_index.tensor_index_by_dim
indices_for_decoder = (
ds_index.tensor_index_by_dim[Dimension.AUTOENCODER_LATENTS],
None,
)
slices_for_decoder: tuple[slice | int | None, ...] = tuple(
slice(None) if index is None else index for index in indices_for_decoder
)
decoder_weight = get_decoder_weight(autoencoder)[slices_for_decoder]
assert decoder_weight.ndim == 1
derived_scalars = derived_scalars * decoder_weight
return derived_scalars
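    # Illustrative sketch of the decoding step in _maybe_decode() above (hypothetical names): for a latent
    # activation `a` and a decoder weight W_dec of shape (n_latents, d), the decoded vector is
    #   decoded = a * W_dec[latent_index]   # (d,)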
def _get_output_weight(self, ds_index: DerivedScalarIndex) -> torch.Tensor:
if ds_index.dst.is_autoencoder_latent:
assert self._multi_autoencoder_context is not None
autoencoder_context = self._multi_autoencoder_context.get_autoencoder_context(
ds_index.dst.node_type
)
assert autoencoder_context is not None
output_dst = autoencoder_context.dst
elif ds_index.dst.node_type == NodeType.ATTENTION_HEAD:
output_dst = DerivedScalarType.ATTN_WEIGHTED_VALUE
else:
output_dst = ds_index.dst
_output_weight_by_dst: dict[DerivedScalarType, WeightLocationType] = {
DerivedScalarType.MLP_POST_ACT: WeightLocationType.MLP_TO_RESIDUAL,
DerivedScalarType.ATTN_WEIGHTED_VALUE: WeightLocationType.ATTN_TO_RESIDUAL,
}
assert output_dst in _output_weight_by_dst, f"{output_dst} must be in output weight dict"
output_weight_location_type = _output_weight_by_dst[output_dst]
weight_shape_spec = output_weight_location_type.shape_spec
weight_tensor_indices = tuple(
[ds_index.tensor_index_by_dim.get(dim, None) for dim in weight_shape_spec]
)
weight_tensor_slices: tuple[slice | int | None, ...] = tuple(
[slice(None) if index is None else index for index in weight_tensor_indices]
)
return self._model_context.get_weight(output_weight_location_type, ds_index.layer_index)[
weight_tensor_slices
]
def postprocess_tensor(
self, ds_index: DerivedScalarIndex, derived_scalars: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
# TODO: rationalize the setup for choosing the raw activations device by getting it from DerivedScalarTypeConfig,
# rather than permitting it as an argument to ScalarDeriver __init__.
# TODO: Derived scalar tensors sometimes haven't been detached yet! We work around that
# by detaching them here, but we should really just make sure they're always detached.
assert len(kwargs) == 0, f"Unexpected kwargs: {kwargs}"
derived_scalars = derived_scalars.to(self._model_context.device).detach()
        # the input can be either a scalar or a vector. In the case of e.g. attention heads, a vector's
        # worth of information is required to reconstruct the write to the residual stream
assert derived_scalars.ndim in {0, 1}
# 1. if an autoencoder latent, return the equivalent model activations; otherwise
# return the derived scalar unchanged
derived_scalars = self._maybe_decode(ds_index, derived_scalars)
# 2. find the output dst
if ds_index.dst.is_autoencoder_latent:
assert self._multi_autoencoder_context is not None
autoencoder_context = self._multi_autoencoder_context.get_autoencoder_context(
ds_index.dst.node_type
)
assert autoencoder_context is not None
output_dst = autoencoder_context.dst
else:
output_dst = ds_index.dst
# 3. convert from model activations to the residual stream write, unless it is already a residual stream write
if output_dst.node_type == NodeType.RESIDUAL_STREAM_CHANNEL:
assert derived_scalars.ndim == 1, f"{ds_index=}, {derived_scalars.shape=}"
return derived_scalars
else:
output_weight = self._get_output_weight(ds_index)
if derived_scalars.ndim == 0:
assert (
output_dst.node_type == NodeType.MLP_NEURON
), f"1-d activation expected only for MLP neurons, not {output_dst.node_type}"
derived_scalars = derived_scalars.unsqueeze(0)
assert output_weight.ndim == 1, f"{output_weight.shape=}"
output_weight = output_weight.unsqueeze(0)
else:
assert derived_scalars.ndim == 1
assert output_weight.ndim == 2
return torch.einsum("a,ad->d", derived_scalars, output_weight)
class TokenWriteConverter(DerivedScalarPostprocessor):
"""
    Converts activations to a direction in token space, using the unembedding matrix. Valid input activations
    are MLP_POST_ACT, ATTN_WEIGHTED_VALUE (equal to post-softmax attention * value), and ONLINE_AUTOENCODER_LATENT.
"""
def __init__(
self,
model_context: ModelContext,
multi_autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None = None,
):
self._model_context = model_context
self._multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
multi_autoencoder_context
)
self._residual_write_converter = ResidualWriteConverter(
model_context, multi_autoencoder_context
)
self._input_dst_by_node_type = self._residual_write_converter._input_dst_by_node_type
self._unemb_with_ln_gain = get_unembedding_with_ln_gain(self._model_context)
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
return self._residual_write_converter.convert_node_index_to_ds_index(node_index)
def postprocess_tensor(
self, ds_index: DerivedScalarIndex, derived_scalars: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
residual_write = self._residual_write_converter.postprocess_tensor(
ds_index, derived_scalars, **kwargs
)
# 3. convert from the residual stream write to the token-space write
unembedded_output = torch.einsum("d,dv->v", residual_write, self._unemb_with_ln_gain)
# 4. subtract the mean, since logprobs are invariant to adding a constant to all logits
mean_subtracted_unembedded_output = unembedded_output - unembedded_output.mean()
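        # illustrative check of this invariance (not executed here; log-softmax is unchanged by subtracting
        # a constant from all logits):
        #   z = torch.randn(10)
        #   assert torch.allclose(torch.log_softmax(z, dim=-1), torch.log_softmax(z - z.mean(), dim=-1))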
return mean_subtracted_unembedded_output
def postprocess_multiple_nodes(
self,
node_indices: list[NodeIndex],
ds_store: DerivedScalarStore,
) -> list[torch.Tensor]:
"""
        First, postprocesses all the nodes with the residual write converter, then stacks the results into a
        single tensor and applies the unembedding matrix to it.
"""
residual_writes = self._residual_write_converter.postprocess_multiple_nodes(
node_indices, ds_store
)
concatenated_residual_writes = torch.stack(residual_writes, dim=0)
unembedded_output = torch.einsum(
"nd,dv->nv", concatenated_residual_writes, self._unemb_with_ln_gain
)
mean_subtracted_unembedded_output = unembedded_output - unembedded_output.mean(
dim=1, keepdim=True
)
split_unembedded_output = torch.split(mean_subtracted_unembedded_output, 1, dim=0)
list_of_tensors = [tensor.squeeze(0) for tensor in split_unembedded_output]
return list_of_tensors
def _get_residual_stream_tensor_indices_for_node(node_index: NodeIndex) -> tuple[int]:
"""For a given node index defining a point from which the gradient will be computed, this identifies the token
indices at which the gradient immediately before the node will be nonzero. For attention, in order for there to
be exactly one such token index, the gradient is computed through one of query/key/value, with a stopgrad
through the others. Depending on which of query/key/value is used, the token index will be either the query token
index or the key/value token index. For MLP neurons, the token index will be the token index of the neuron.
"""
# tensor_indices are expected to be tuple[int, ...], even if length 1
match node_index.node_type:
case NodeType.ATTENTION_HEAD:
# in the case of attention head reads, there are several possible ways to interpret the "read" direction
# - the gradient through just the query (at the query token)
# - the gradient through just the key (at the key/value token)
# - the gradient with respect to some function of the attention write, e.g. the attention write norm,
# through just the value (at the key/value token)
assert isinstance(node_index, AttnSubNodeIndex)
assert len(node_index.tensor_indices) == 3
match node_index.q_k_or_v:
case ActivationLocationType.ATTN_QUERY:
tensor_index = node_index.tensor_indices[1] # just the query token index
case ActivationLocationType.ATTN_KEY | ActivationLocationType.ATTN_VALUE:
tensor_index = node_index.tensor_indices[0] # just the key/value token index
case _:
raise ValueError(f"Unexpected q_k_or_v: {node_index.q_k_or_v}")
case (
NodeType.MLP_NEURON
| NodeType.AUTOENCODER_LATENT
| NodeType.MLP_AUTOENCODER_LATENT
| NodeType.ATTENTION_AUTOENCODER_LATENT
):
assert len(node_index.tensor_indices) == 2
tensor_index = node_index.tensor_indices[0] # just the token index
case _:
raise ValueError(f"Node type {node_index.node_type} not supported")
assert isinstance(tensor_index, int), (tensor_index, type(tensor_index))
return (tensor_index,)
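# Illustrative examples of the index extraction in _get_residual_stream_tensor_indices_for_node
# (hypothetical index values):
#   - MLP neuron with tensor_indices == (token_index, neuron_index)      -> (token_index,)
#   - attention head with q_k_or_v == ATTN_QUERY                         -> (tensor_indices[1],)
#   - attention head with q_k_or_v == ATTN_KEY or ATTN_VALUE             -> (tensor_indices[0],)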
class ResidualReadConverter(DerivedScalarPostprocessor):
"""
    Converts activations to a gradient direction in residual stream space, by taking functions that recompute
    those activations from the residual stream and computing a backward pass through them. Valid input
    activations are PREVIOUS_LAYER_RESID_POST_MLP and RESID_POST_ATTN (the DSTs corresponding to the residual
    stream locations that precede attention heads and MLP neurons, respectively).
"""
def __init__(
self,
model_context: ModelContext,
multi_autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None = None,
):
assert isinstance(model_context, StandardModelContext)
self._transformer = model_context.get_or_create_model()
self._device = model_context.device
self._multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
multi_autoencoder_context
)
# TODO: support attention heads; this will require specifying q, k or v in the make_reconstituted_gradient_fn
self._input_dst_by_node_type: dict[NodeType, DerivedScalarType] = {}
self._input_dst_by_node_type[NodeType.MLP_NEURON] = get_previous_residual_dst_for_node_type(
NodeType.MLP_NEURON, None
)
self._input_dst_by_node_type[
NodeType.ATTENTION_HEAD
] = get_previous_residual_dst_for_node_type(NodeType.ATTENTION_HEAD, None)
if self._multi_autoencoder_context is not None:
# add the autoencoders listed in the multi_autoencoder_context, using their node types
for (
node_type,
autoencoder_context,
) in self._multi_autoencoder_context.autoencoder_context_by_node_type.items():
self._input_dst_by_node_type[node_type] = get_previous_residual_dst_for_node_type(
node_type, autoencoder_context.dst
)
# if there is only one autoencoder context, also add the "default" node type for backwards compatibility
if self._multi_autoencoder_context.has_single_autoencoder_context:
autoencoder_context = (
self._multi_autoencoder_context.get_single_autoencoder_context()
)
self._input_dst_by_node_type[
NodeType.AUTOENCODER_LATENT
] = get_previous_residual_dst_for_node_type(
NodeType.AUTOENCODER_LATENT, autoencoder_context.dst
)
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
if node_index.node_type == NodeType.ATTENTION_HEAD:
# see _get_residual_stream_tensor_indices_for_node for more information
# TODO: finish supporting attention heads
assert isinstance(node_index, AttnSubNodeIndex), (
node_index.node_type,
type(node_index),
)
assert node_index.q_k_or_v in {
ActivationLocationType.ATTN_QUERY,
ActivationLocationType.ATTN_KEY,
}
        supported_node_types = list(self._input_dst_by_node_type.keys())
        assert node_index.node_type in supported_node_types, (
            f"Node type {node_index.node_type} not supported by this converter; "
            f"supported node types are {supported_node_types}"
        )
        dst_for_computing_grad = self._input_dst_by_node_type[node_index.node_type]
updated_tensor_indices = _get_residual_stream_tensor_indices_for_node(node_index)
# note: derived scalar indices do not have q_k_or_v associated to them, so we remove this field
updated_node_index = NodeIndex(
node_type=dst_for_computing_grad.node_type,
# Remove the activation index; the entire residual stream will be needed for computing
# the gradient.
tensor_indices=updated_tensor_indices,
layer_index=node_index.layer_index,
pass_type=node_index.pass_type,
)
return DerivedScalarIndex.from_node_index(
updated_node_index,
dst_for_computing_grad,
)
def get_postprocess_tensor_kwargs(
self, node_index: NodeIndex, _unused_ds_store: DerivedScalarStore
) -> dict[str, Any]:
return {"node_index": node_index}
def postprocess_tensor(
self, ds_index: DerivedScalarIndex, derived_scalars: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
# TODO: rationalize the setup for choosing the raw activations device by getting it from DerivedScalarTypeConfig,
# rather than permitting it as an argument to ScalarDeriver __init__.
# TODO: Derived scalar tensors sometimes haven't been detached yet! We work around that
# by detaching them here, but we should really just make sure they're always detached.
node_index = kwargs.pop("node_index")
assert len(kwargs) == 0, f"Unexpected kwargs: {kwargs}"
assert (
ds_index.pass_type == PassType.FORWARD
), "Residual read converter only supports forward pass"
derived_scalars = derived_scalars.to(self._device).detach()
        # the input should be a residual stream vector (1-d)
assert derived_scalars.ndim == 1
node_index_with_singleton_first_dim = node_index.with_updates(
tensor_indices=(0,) + node_index.tensor_indices[1:]
)
trace_config = TraceConfig(
node_index=node_index_with_singleton_first_dim,
pre_or_post_act=PreOrPostAct.PRE,
detach_layer_norm_scale=DETACH_LAYER_NORM_SCALE,
)
# 1. create the function that computes the residual stream gradient
if trace_config.node_type.is_autoencoder_latent:
assert self._multi_autoencoder_context is not None
autoencoder_context = self._multi_autoencoder_context.get_autoencoder_context(
trace_config.node_type
)
assert autoencoder_context is not None
else:
autoencoder_context = None
reconstitute_gradient = make_reconstituted_gradient_fn(
transformer=self._transformer,
autoencoder_context=autoencoder_context,
trace_config=trace_config,
)
# 2. apply the function to the residual stream vector to get the residual stream gradient ("read" vector)
        # add and then remove a token dimension for compatibility with reconstitute_gradient
        residual_read = reconstitute_gradient(
            derived_scalars[None], ds_index.layer_index, PassType.FORWARD
        )[0]
return residual_read
class TokenReadConverter(DerivedScalarPostprocessor):
"""
Converts activations to a direction in token space, by computing a gradient as in ResidualReadConverter,
and projecting it into token space using the embedding matrix. Valid activations are
PREVIOUS_LAYER_RESID_POST_MLP and RESID_POST_ATTN (the DSTs corresponding to residual stream locations that
precede attention heads and MLP neurons, respectively)
"""
def __init__(
self,
model_context: ModelContext,
multi_autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None = None,
):
self._model_context = model_context
self._multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
multi_autoencoder_context
)
self._residual_read_converter = ResidualReadConverter(
model_context, multi_autoencoder_context
)
self._input_dst_by_node_type = self._residual_read_converter._input_dst_by_node_type
self._emb = get_embedding(self._model_context)
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
return self._residual_read_converter.convert_node_index_to_ds_index(node_index)
def get_postprocess_tensor_kwargs(
self, node_index: NodeIndex, _unused_ds_store: DerivedScalarStore
) -> dict[str, Any]:
return self._residual_read_converter.get_postprocess_tensor_kwargs(
node_index, _unused_ds_store
)
def postprocess_tensor(
self, ds_index: DerivedScalarIndex, derived_scalars: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
residual_read = self._residual_read_converter.postprocess_tensor(
ds_index, derived_scalars, **kwargs
)
# 3. convert from the residual stream read to the token-space read
return torch.einsum("d,vd->v", residual_read, self._emb)
def postprocess_multiple_nodes(
self,
node_indices: list[NodeIndex],
ds_store: DerivedScalarStore,
) -> list[torch.Tensor]:
"""
        First, postprocesses all the nodes with the residual read converter, then stacks the results into a
        single tensor and applies the embedding matrix to it.
"""
residual_reads = self._residual_read_converter.postprocess_multiple_nodes(
node_indices, ds_store
)
concatenated_residual_reads = torch.stack(residual_reads, dim=0)
embedded_output = torch.einsum("nd,vd->nv", concatenated_residual_reads, self._emb)
        split_embedded_output = torch.split(embedded_output, 1, dim=0)
        list_of_tensors = [tensor.squeeze(0) for tensor in split_embedded_output]
        return list_of_tensors
class TokenPairAttributionConverter(DerivedScalarPostprocessor):
"""
    Converts activations of an attention-write autoencoder into attributions to each token pair, i.e. the
    direct contribution of each attended-to token's attention write to the latent at the given sequence token.
"""
_input_dst_by_node_type: dict[NodeType, DerivedScalarType] = {
NodeType.ATTENTION_AUTOENCODER_LATENT: DerivedScalarType.ATTENTION_AUTOENCODER_LATENT,
}
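    # Conceptually, the attribution of the latent to each attended-to token follows the chain rule: project
    # each token's summed attention write onto the latent's encoder row, then scale by the derivative of the
    # autoencoder activation function at the latent's pre-activation. Illustrative sketch (hypothetical
    # names; the actual computation is in postprocess_tensor() below):
    #   projection = attn_write_sum_heads @ W_enc[latent_index]   # (n_tokens_attended_to,)
    #   attribution = projection * d_latent_d_pre_act             # (n_tokens_attended_to,)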
def __init__(
self,
model_context: ModelContext,
multi_autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None,
num_tokens_attended_to: int,
):
self._model_context = model_context
self._multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
multi_autoencoder_context
)
self.num_tokens_attended_to = num_tokens_attended_to
def postprocess(
self,
node_index: NodeIndex | MirroredNodeIndex,
ds_store: DerivedScalarStore,
) -> torch.Tensor:
if node_index.node_type not in self._input_dst_by_node_type:
raise ValueError(f"Node type {node_index.node_type} not supported")
elif self._multi_autoencoder_context is not None:
autoencoder_context = self._multi_autoencoder_context.get_autoencoder_context(
node_index.node_type
)
if autoencoder_context is None:
raise ValueError(
f"No autoencoder context found for node type {node_index.node_type}."
)
if autoencoder_context.dst != DerivedScalarType.RESID_DELTA_ATTN:
raise ValueError(
"Autoencoder context found, but derived scalar type is not RESID_DELTA_ATTN."
)
        # otherwise proceed with the standard postprocessing flow
        return super().postprocess(node_index, ds_store)
def convert_node_index_to_ds_index(self, node_index: NodeIndex) -> DerivedScalarIndex:
dst = self._input_dst_by_node_type[node_index.node_type]
ds_index = DerivedScalarIndex.from_node_index(
node_index.with_updates(
node_type=dst.node_type, tensor_indices=node_index.tensor_indices
),
dst,
)
return ds_index
def postprocess_tensor(
self, ds_index: DerivedScalarIndex, derived_scalars: torch.Tensor, **kwargs: Any
) -> torch.Tensor:
from neuron_explainer.activations.derived_scalars.autoencoder import (
make_autoencoder_activation_fn_derivative,
make_autoencoder_pre_act_encoder_derivative,
)
attn_write_sum_heads = kwargs.pop("attn_write_sum_heads")
assert len(kwargs) == 0, f"Unexpected kwargs: {kwargs}"
layer_index = ds_index.layer_index
token_index, latent_index = ds_index.tensor_indices
assert self._multi_autoencoder_context is not None
autoencoder_context = self._multi_autoencoder_context.get_autoencoder_context(
NodeType.ATTENTION_AUTOENCODER_LATENT
)
assert autoencoder_context is not None
assert layer_index is not None
# compute the activation function derivative
activation_fn_derivative = make_autoencoder_activation_fn_derivative(
autoencoder_context, layer_index
)
latent_activation = derived_scalars # (,)
d_latent_d_pre_act = activation_fn_derivative(latent_activation) # (,)
if d_latent_d_pre_act == 0:
raise ValueError("Latent is inactive.")
# compute the encoder derivative
pre_act_encoder_derivative = make_autoencoder_pre_act_encoder_derivative(
autoencoder_context, layer_index, latent_index=latent_index
)
n_tokens_attended_to, d_model = attn_write_sum_heads.shape # already indexed by token_index
projection = pre_act_encoder_derivative(attn_write_sum_heads) # (n_tokens_attended_to, 1)
projection = projection[:, 0] # (n_tokens_attended_to,)
direct_write_to_latents = projection * d_latent_d_pre_act # (n_tokens_attended_to, )
        # the result must have at least one dimension, because a zero-dimensional tensor
        # (torch.tensor(torch.nan)) is reserved for cases where the postprocessor is not supported
        assert direct_write_to_latents.ndim >= 1
return direct_write_to_latents
def get_constitutive_dst_and_config_list(self) -> list[tuple[DerivedScalarType, DstConfig]]:
return [
(
DerivedScalarType.ATTN_WRITE_SUM_HEADS,
DstConfig(
model_context=self._model_context,
),
)
]
def get_postprocess_tensor_kwargs(
self, node_index: NodeIndex, ds_store: DerivedScalarStore
) -> dict[str, Any]:
sequence_token_index = node_index.tensor_indices[0]
layer_index = node_index.layer_index
attn_write_sum_heads = ds_store[
DerivedScalarIndex(
dst=DerivedScalarType.ATTN_WRITE_SUM_HEADS,
layer_index=layer_index,
pass_type=PassType.FORWARD,
tensor_indices=(sequence_token_index, None),
)
]
return {"attn_write_sum_heads": attn_write_sum_heads}