# neuron_explainer/activations/derived_scalars/reconstituted.py
from typing import Any, Callable

import torch

from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.indexing import AttentionTraceType, PreOrPostAct
from neuron_explainer.models import Autoencoder
from neuron_explainer.models.autoencoder_context import AutoencoderContext
from neuron_explainer.models.hooks import AttentionHooks, NormalizationHooks, TransformerHooks
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
LayerIndex,
NodeType,
PassType,
)
from neuron_explainer.models.transformer import Norm, Transformer, TransformerLayer


# Scalar derivers that take the residual stream as input and reconstitute
# activations such as attention post-softmax and MLP post-activations.


def detach_hook(x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
return x.detach()


def make_hook_getter() -> tuple[Callable[..., Any], Callable[[], Any]]:
    """
    Returns a hook to append, and a function that retrieves the tensor observed by the hook.
    The retrieval function must be called only after the hook has been called.
    """
retrieve = {}
def hook(x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
retrieve["value"] = x
return x
return hook, lambda: retrieve["value"]
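
# Example (illustrative sketch, not part of this module): capturing an intermediate
# tensor with make_hook_getter. `some_hooks` is a stand-in for any hook list with an
# `append_fwd` method, e.g. AttentionHooks().qk_probs:
#
#   get_hook, get_value = make_hook_getter()
#   some_hooks.append_fwd(get_hook)
#   ...  # run the forward pass that triggers the hook
#   captured = get_value()  # the tensor observed by the hook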


def zero_batch_dim_hook(x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
    """Truncate the batch dimension to zero. Appending this hook before computations
    whose results are not needed makes them run on an empty batch, saving compute."""
    return x[:0, ...]


def apply_layer_norm(
x: torch.Tensor, norm_module: Norm, detach_layer_norm_scale: bool
) -> torch.Tensor:
hooks = NormalizationHooks()
if detach_layer_norm_scale:
hooks = hooks.append_to_path("scale.fwd", detach_hook)
return norm_module(x, hooks=hooks)


def add_q_k_or_v_detach_hook(
hooks: AttentionHooks, q_k_or_v: ActivationLocationType | None
) -> None:
"""If q_k_or_v is None, leave everything attached. If q_k_or_v is not None,
then leave only the corresponding tensor (Q, K, or V) attached."""
if q_k_or_v is None:
return
if q_k_or_v != ActivationLocationType.ATTN_QUERY:
hooks.q.append_fwd(detach_hook)
if q_k_or_v != ActivationLocationType.ATTN_KEY:
hooks.k.append_fwd(detach_hook)
if q_k_or_v != ActivationLocationType.ATTN_VALUE:
hooks.v.append_fwd(detach_hook)


def apply_attn_pre_softmax(
transformer_layer: TransformerLayer,
q_k_or_v: ActivationLocationType | None,
resid_post_mlp: torch.Tensor,
detach_layer_norm_scale: bool,
) -> torch.Tensor:
attn_input = apply_layer_norm(
resid_post_mlp.unsqueeze(0),
transformer_layer.ln_1,
detach_layer_norm_scale=detach_layer_norm_scale,
) # add batch dimension
hooks = AttentionHooks()
add_q_k_or_v_detach_hook(hooks, q_k_or_v)
get_hook, get_attn = make_hook_getter()
hooks.qk_logits.append_fwd(get_hook)
    # zero the batch dimension of V (unused here) and, after capture, of the QK logits,
    # so the rest of the attention computation (including v_out) runs on an empty batch
hooks.v.append_fwd(zero_batch_dim_hook)
hooks.qk_logits.append_fwd(zero_batch_dim_hook)
transformer_layer.attn.forward(attn_input, hooks=hooks)
# remove batch dimension
return get_attn()[0]


def apply_mlp_act(
transformer_layer: TransformerLayer,
resid_post_attn: torch.Tensor,
detach_layer_norm_scale: bool,
) -> torch.Tensor:
pre_act = apply_mlp_pre_act(transformer_layer, resid_post_attn, detach_layer_norm_scale)
post_act = transformer_layer.mlp.act(pre_act)
return post_act


def apply_mlp_pre_act(
transformer_layer: TransformerLayer,
resid_post_attn: torch.Tensor,
detach_layer_norm_scale: bool,
) -> torch.Tensor:
post_ln_mlp = apply_layer_norm(
resid_post_attn.unsqueeze(0),
transformer_layer.ln_2,
detach_layer_norm_scale=detach_layer_norm_scale,
) # add batch dimension
pre_act = transformer_layer.mlp.in_layer(post_ln_mlp)
return pre_act.squeeze(0) # remove batch dimension
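
# Example (sketch): reconstituting MLP pre-activations for one layer. `transformer`,
# `resid`, and `layer_index` are assumed to come from the surrounding pipeline;
# `resid` has shape (n_tokens, d_model):
#
#   pre_act = apply_mlp_pre_act(
#       transformer_layer=transformer.xf_layers[layer_index],
#       resid_post_attn=resid,
#       detach_layer_norm_scale=True,
#   )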


def apply_autoencoder_pre_latent(
transformer_layer: TransformerLayer,
autoencoder: Autoencoder,
resid: torch.Tensor,
autoencoder_dst: DerivedScalarType,
detach_layer_norm_scale: bool,
latent_slice: slice = slice(None),
) -> torch.Tensor:
"""
Given the residual stream activations preceding an autoencoder to be
applied to a given DST, first compute the activations of the DST (`to_be_encoded`)
and then apply the autoencoder to these activations (NOT INCLUDING the autoencoder nonlinearity),
and return the result.
"""
match autoencoder_dst:
case DerivedScalarType.MLP_POST_ACT:
to_be_encoded = apply_mlp_act(
transformer_layer,
resid,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case DerivedScalarType.RESID_DELTA_ATTN:
to_be_encoded = apply_resid_delta_attn(
transformer_layer,
resid,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case DerivedScalarType.RESID_DELTA_MLP:
to_be_encoded = apply_resid_delta_mlp(
transformer_layer,
resid,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case _:
raise NotImplementedError(autoencoder_dst.node_type)
return autoencoder.encode_pre_act(to_be_encoded, latent_slice=latent_slice)


def apply_autoencoder_latent(
transformer_layer: TransformerLayer,
autoencoder: Autoencoder,
resid: torch.Tensor,
autoencoder_dst: DerivedScalarType,
detach_layer_norm_scale: bool,
) -> torch.Tensor:
"""
Given the residual stream activations preceding an autoencoder to be
applied to a given DST, first compute the activations of the DST (`to_be_encoded`)
and then apply the autoencoder to these activations (INCLUDING the autoencoder nonlinearity),
and return the result.
"""
pre_latent = apply_autoencoder_pre_latent(
transformer_layer,
autoencoder,
resid,
autoencoder_dst,
detach_layer_norm_scale=detach_layer_norm_scale,
)
return autoencoder.activation(pre_latent)
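
# Example (sketch): reconstituting autoencoder latents from the residual stream for an
# autoencoder trained on MLP post-activations. `transformer`, `autoencoder_context`,
# `resid`, and `layer_index` are placeholders for values produced elsewhere:
#
#   latents = apply_autoencoder_latent(
#       transformer_layer=transformer.xf_layers[layer_index],
#       autoencoder=autoencoder_context.get_autoencoder(layer_index),
#       resid=resid,
#       autoencoder_dst=DerivedScalarType.MLP_POST_ACT,
#       detach_layer_norm_scale=True,
#   )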


def apply_resid_delta_attn(
transformer_layer: TransformerLayer, resid_post_mlp: torch.Tensor, detach_layer_norm_scale: bool
) -> torch.Tensor:
"""
Compute resid_delta_attn (the output of an attention layer) from the residual stream
just before the layer
"""
X = resid_post_mlp.unsqueeze(0)
hooks = TransformerHooks()
if detach_layer_norm_scale:
hooks = hooks.append_to_path("resid.torso.ln_attn.scale.fwd", detach_hook)
    # kv_cache and pad are passed as None to satisfy the signature of attn_block;
    # the second output is the updated kv_cache, which is not used here
attn_delta, _ = transformer_layer.attn_block(X, kv_cache=None, pad=None, hooks=hooks)
return attn_delta.squeeze(0)


def apply_resid_delta_mlp(
transformer_layer: TransformerLayer,
resid_post_attn: torch.Tensor,
detach_layer_norm_scale: bool,
) -> torch.Tensor:
"""
Compute resid_delta_mlp (the output of an MLP layer) from the residual stream
just before the layer
"""
X = resid_post_attn.unsqueeze(0)
hooks = TransformerHooks()
if detach_layer_norm_scale:
hooks = hooks.append_to_path("resid.torso.ln_mlp.scale.fwd", detach_hook)
    # hooks (empty unless the layer norm detach hook was added above) are passed to
    # satisfy the signature of mlp_block
mlp_delta = transformer_layer.mlp_block(X, hooks=hooks)
return mlp_delta.squeeze(0)


def make_reconstituted_activation_fn(
transformer: Transformer,
autoencoder_context: AutoencoderContext | None,
node_type: NodeType,
pre_or_post_act: PreOrPostAct | None,
detach_layer_norm_scale: bool,
attention_trace_type: AttentionTraceType | None,
) -> Callable[[torch.Tensor, LayerIndex, PassType], torch.Tensor]:
match node_type:
case NodeType.ATTENTION_HEAD:
match attention_trace_type:
case AttentionTraceType.QK:
q_or_k = None
case AttentionTraceType.Q:
q_or_k = ActivationLocationType.ATTN_QUERY
case AttentionTraceType.K:
q_or_k = ActivationLocationType.ATTN_KEY
case None:
raise ValueError(
"attention_trace_type must be specified for attention activations"
)
match pre_or_post_act:
case PreOrPostAct.PRE:
def act_fn(
resid: torch.Tensor,
layer_index: int | None,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
assert layer_index is not None
return apply_attn_pre_softmax(
transformer_layer=transformer.xf_layers[layer_index],
q_k_or_v=q_or_k,
resid_post_mlp=resid,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case PreOrPostAct.POST:
apply_attn_V_act = make_apply_attn_V_act(
transformer=transformer,
q_k_or_v=q_or_k,
detach_layer_norm_scale=detach_layer_norm_scale,
) # returns attn, V
def act_fn(
resid: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
assert layer_index is not None
                        # apply_attn_V_act returns (attn, V); keep only attn
                        return apply_attn_V_act(resid, layer_index, pass_type)[0]
case _:
raise NotImplementedError(pre_or_post_act)
case NodeType.MLP_NEURON:
match pre_or_post_act:
case PreOrPostAct.PRE:
def act_fn(
resid: torch.Tensor,
layer_index: int | None,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
assert layer_index is not None
return apply_mlp_pre_act(
transformer_layer=transformer.xf_layers[layer_index],
resid_post_attn=resid,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case PreOrPostAct.POST:
def act_fn(
resid: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
assert layer_index is not None
return apply_mlp_act(
transformer_layer=transformer.xf_layers[layer_index],
resid_post_attn=resid,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case _:
raise NotImplementedError(pre_or_post_act)
case (
NodeType.AUTOENCODER_LATENT
| NodeType.MLP_AUTOENCODER_LATENT
| NodeType.ATTENTION_AUTOENCODER_LATENT
):
assert autoencoder_context is not None
match pre_or_post_act:
case PreOrPostAct.PRE:
def act_fn(
resid: torch.Tensor,
layer_index: int | None,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
assert layer_index is not None
return apply_autoencoder_pre_latent(
transformer_layer=transformer.xf_layers[layer_index],
autoencoder=autoencoder_context.get_autoencoder(layer_index),
resid=resid,
autoencoder_dst=autoencoder_context.dst,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case PreOrPostAct.POST:
def act_fn(
resid: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
assert layer_index is not None
return apply_autoencoder_latent(
transformer_layer=transformer.xf_layers[layer_index],
autoencoder=autoencoder_context.get_autoencoder(layer_index),
resid=resid,
autoencoder_dst=autoencoder_context.dst,
detach_layer_norm_scale=detach_layer_norm_scale,
)
case _:
raise NotImplementedError(pre_or_post_act)
case _:
raise NotImplementedError(node_type)
return act_fn
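
# Example (sketch): building an activation function that reconstitutes MLP
# post-activations from a saved residual stream tensor. `transformer`, `resid`, and
# `layer_index` are assumed to be provided by the caller:
#
#   act_fn = make_reconstituted_activation_fn(
#       transformer=transformer,
#       autoencoder_context=None,
#       node_type=NodeType.MLP_NEURON,
#       pre_or_post_act=PreOrPostAct.POST,
#       detach_layer_norm_scale=True,
#       attention_trace_type=None,
#   )
#   mlp_post_act = act_fn(resid, layer_index, PassType.FORWARD)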


def make_apply_attn_V_act(
transformer: Transformer,
q_k_or_v: ActivationLocationType | None,
detach_layer_norm_scale: bool,
) -> Callable[[torch.Tensor, LayerIndex, PassType], tuple[torch.Tensor, torch.Tensor]]:
"""Used in functions that require reconstituting some or all of the attention head
operation. Supports specifying a stop grad through all but one of Q, K, and V; or
if q_k_or_v is None, then all of Q, K, and V are backprop'ed through."""
def apply_attn_V_act(
resid: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> tuple[torch.Tensor, torch.Tensor]:
assert pass_type == PassType.FORWARD
transformer_layer = transformer.xf_layers[layer_index]
attn_input = apply_layer_norm(
resid.unsqueeze(0),
transformer_layer.ln_1,
detach_layer_norm_scale=detach_layer_norm_scale,
) # add batch dimension
hooks = AttentionHooks()
add_q_k_or_v_detach_hook(hooks, q_k_or_v)
get_hook, get_v = make_hook_getter()
hooks.v.append_fwd(get_hook)
get_hook, get_attn = make_hook_getter()
hooks.qk_probs.append_fwd(get_hook)
        # zero the batch dimension of V and the QK probs after they are captured, so the
        # rest of the attention computation (including v_out) runs on an empty batch
hooks.v.append_fwd(zero_batch_dim_hook)
hooks.qk_probs.append_fwd(zero_batch_dim_hook)
transformer_layer.attn.forward(attn_input, hooks=hooks)
# remove batch dimensions
return get_attn()[0], get_v()[0]
return apply_attn_V_act


def make_apply_logits(
transformer: Transformer,
detach_layer_norm_scale: bool,
) -> Callable[[torch.Tensor], torch.Tensor]:
def apply_logits(
resid_post_mlp: torch.Tensor,
) -> torch.Tensor:
"""
Input: (n_sequence_tokens, d_model) residual stream post-mlp activations at final layer.
        Output: (n_sequence_tokens, n_vocab) logits for each token in the sequence.
"""
post_ln_f = apply_layer_norm(
resid_post_mlp.unsqueeze(0),
transformer.final_ln,
detach_layer_norm_scale=detach_layer_norm_scale,
) # add batch dimension
return transformer.unembed(post_ln_f).squeeze(0) # remove batch dimension
return apply_logits


def make_apply_logprobs(
transformer: Transformer,
detach_layer_norm_scale: bool,
) -> Callable[[torch.Tensor], torch.Tensor]:
def apply_logprobs(
resid_post_mlp: torch.Tensor,
) -> torch.Tensor:
"""
Input: (n_sequence_tokens, d_model) residual stream post-mlp activations at final layer.
Output: (n_sequence_tokens, n_vocab) logprobs for each token in the sequence.
"""
logits = make_apply_logits(transformer, detach_layer_norm_scale)(resid_post_mlp)
return torch.log_softmax(logits, dim=-1)
return apply_logprobs
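
# Example (sketch): converting final-layer residual stream activations into logprobs.
# `transformer` and `resid_post_mlp` are assumed inputs with the shapes documented above:
#
#   apply_logprobs = make_apply_logprobs(transformer, detach_layer_norm_scale=True)
#   logprobs = apply_logprobs(resid_post_mlp)  # (n_sequence_tokens, n_vocab)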


def make_apply_autoencoder(
autoencoder_context: AutoencoderContext,
use_no_grad: bool = True, # use True to avoid keeping gradient info for autoencoder;
# TODO: consider deleting in favor of universal non-gradient-keeping at the outside of ScalarDeriver base functions
) -> Callable[[torch.Tensor, LayerIndex], torch.Tensor]:
"""
Returns a function that takes a tensor of activations and returns a tensor of the autoencoder
latent representation of each token.
"""
# TODO(sbills): Resolve the circular import between this file and attention.py.
from neuron_explainer.activations.derived_scalars.attention import make_reshape_fn
# reshape activations to be (n_tokens, n_inputs)
dst = autoencoder_context.dst
reshape_fn = make_reshape_fn(dst)
def apply_autoencoder(raw_activations: torch.Tensor, layer_index: LayerIndex) -> torch.Tensor:
assert (
layer_index in autoencoder_context.layer_indices
), f"Layer index {layer_index} not in {autoencoder_context.layer_indices}"
autoencoder = autoencoder_context.get_autoencoder(layer_index)
latent_activations = autoencoder.encode(reshape_fn(raw_activations))
return latent_activations # shape (n_tokens, n_latents)
if use_no_grad:
apply_autoencoder = torch.no_grad()(apply_autoencoder)
return apply_autoencoder
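
# Example (sketch): encoding already-computed raw activations into autoencoder latents.
# `autoencoder_context` and `raw_activations` are assumed to come from the surrounding
# derived-scalar machinery:
#
#   apply_autoencoder = make_apply_autoencoder(autoencoder_context)
#   latents = apply_autoencoder(raw_activations, layer_index)  # (n_tokens, n_latents)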