neuron_explainer/activations/derived_scalars/autoencoder.py
"""
This file contains code to compute derived scalars related to autoencoder latents post-hoc
(that is, from pre-existing MLP activations). Typically, the derived scalars consist of the
autoencoder latent activation multiplied by some other quantity.
"""
from typing import Callable
import torch
from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.direct_effects import (
convert_scalar_deriver_to_write_to_final_residual_grad,
)
from neuron_explainer.activations.derived_scalars.raw_activations import (
convert_scalar_deriver_to_write_norm,
convert_scalar_deriver_to_write_vector,
no_op_tensor_calculate_derived_scalar_fn,
)
from neuron_explainer.activations.derived_scalars.reconstituted import make_apply_autoencoder
from neuron_explainer.activations.derived_scalars.reconstituter_class import (
WriteLatentReconstituter,
)
from neuron_explainer.activations.derived_scalars.scalar_deriver import (
DerivedScalarSource,
DstConfig,
RawScalarSource,
ScalarDeriver,
)
from neuron_explainer.activations.derived_scalars.utils import detach_and_clone
from neuron_explainer.activations.derived_scalars.write_tensors import (
get_autoencoder_write_tensor_by_layer_index,
get_mlp_write_tensor_by_layer_index_with_autoencoder_context,
)
from neuron_explainer.models.autoencoder_context import AutoencoderContext
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
LayerIndex,
NodeType,
PassType,
)
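

# Example usage (a minimal sketch; exactly how DstConfig is constructed and how these factories
# are registered lives elsewhere, e.g. in make_scalar_derivers.py, and is assumed here rather
# than shown in this file):
#
#   make_deriver = make_autoencoder_latent_scalar_deriver_factory(NodeType.MLP_AUTOENCODER_LATENT)
#   scalar_deriver = make_deriver(dst_config)  # dst_config: a DstConfig with an autoencoder context
#   # scalar_deriver now derives autoencoder latent activations from stored MLP activations.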


def make_autoencoder_latent_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
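    """Return a function that builds a ScalarDeriver for autoencoder latent activations,
    computed post-hoc by applying the autoencoder to its stored input activations."""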
def make_autoencoder_latent_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
# import here to avoid circular import
from neuron_explainer.activations.derived_scalars.make_scalar_derivers import (
make_scalar_deriver,
)
layer_indices = dst_config.layer_indices
if layer_indices is None:
model_context = dst_config.get_model_context()
layer_indices = list(range(model_context.n_layers))
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
autoencoder_dst = autoencoder_context.dst
autoencoder_dst = maybe_convert_autoencoder_dst(autoencoder_dst)
scalar_deriver = make_scalar_deriver(autoencoder_dst, dst_config)
apply_autoencoder = make_apply_autoencoder(autoencoder_context)
def new_tensor_calculate_derived_scalar_fn(
derived_scalar_tensor: torch.Tensor,
layer_index: LayerIndex,
pass_type: PassType,
) -> torch.Tensor:
assert pass_type == PassType.FORWARD
return apply_autoencoder(derived_scalar_tensor, layer_index)
output_dst = DerivedScalarType.AUTOENCODER_LATENT.update_from_autoencoder_node_type(
node_type
)
new_scalar_deriver = scalar_deriver.apply_layerwise_transform_fn_to_output(
layerwise_transform_fn=new_tensor_calculate_derived_scalar_fn,
pass_type_to_transform=PassType.FORWARD,
output_dst=output_dst,
)
return new_scalar_deriver
return make_autoencoder_latent_scalar_deriver


def _make_autoencoder_latent_grad_wrt_input_scalar_deriver_helper(
dst_config: DstConfig,
output_dst: DerivedScalarType,
node_type: NodeType | None = None,
) -> ScalarDeriver:
"""Compute the gradient from a particular autoencoder latent, with respect to the
autoencoder input directions.
Requires:
>>> dst_config.layer_indices = [layer_index]
>>> dst_config.trace_config.tensor_indices = ("All", latent_index)
"""
# import here to avoid circular import
from neuron_explainer.activations.derived_scalars.make_scalar_derivers import (
make_scalar_deriver,
)
trace_config = dst_config.trace_config
assert trace_config is not None
assert trace_config.node_type in [
NodeType.AUTOENCODER_LATENT,
NodeType.MLP_AUTOENCODER_LATENT,
NodeType.ATTENTION_AUTOENCODER_LATENT,
NodeType.AUTOENCODER_LATENT_BY_TOKEN_PAIR,
]
assert trace_config.tensor_indices[0] == "All"
assert isinstance(trace_config.tensor_indices[1], int)
latent_index_for_grad = trace_config.tensor_indices[1]
layer_index = trace_config.layer_index
assert layer_index is not None
if dst_config.layer_indices is not None:
assert [layer_index] == dst_config.layer_indices
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
autoencoder_dst = autoencoder_context.dst
autoencoder_dst = maybe_convert_autoencoder_dst(autoencoder_dst)
scalar_deriver = make_scalar_deriver(autoencoder_dst, dst_config)
apply_autoencoder = make_apply_autoencoder(autoencoder_context, use_no_grad=False)
def new_tensor_calculate_derived_scalar_fn(
derived_scalar_tensor: torch.Tensor,
) -> torch.Tensor:
derived_scalar_tensor = detach_and_clone(derived_scalar_tensor, requires_grad=True)
latents = apply_autoencoder(derived_scalar_tensor, layer_index)
latents[:, latent_index_for_grad].sum(dim=0).backward() # sum over tokens
assert derived_scalar_tensor.grad is not None
return derived_scalar_tensor.grad
new_scalar_deriver = scalar_deriver.apply_transform_fn_to_output(
transform_fn=new_tensor_calculate_derived_scalar_fn,
pass_type_to_transform=PassType.FORWARD,
output_dst=output_dst,
)
return new_scalar_deriver


def make_autoencoder_latent_grad_wrt_residual_input_scalar_deriver(
dst_config: DstConfig,
node_type: NodeType | None = None,
) -> ScalarDeriver:
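    """Build a ScalarDeriver for the gradient of the autoencoder latent specified by
    dst_config.trace_config.node_index with respect to the residual stream input."""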
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
assert dst_config.trace_config is not None
latent_index = dst_config.trace_config.node_index
assert latent_index is not None
latent_reconstituter = WriteLatentReconstituter(autoencoder_context)
return latent_reconstituter.make_gradient_scalar_deriver_for_latent_index(
latent_index=latent_index,
dst_config=dst_config,
output_dst=DerivedScalarType.AUTOENCODER_LATENT_GRAD_WRT_RESIDUAL_INPUT,
)


def make_autoencoder_latent_grad_wrt_residual_input_scalar_source(
dst_config: DstConfig,
node_type: NodeType | None = None,
) -> DerivedScalarSource:
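    """Like make_autoencoder_latent_grad_wrt_residual_input_scalar_deriver, but returns a
    DerivedScalarSource rather than a ScalarDeriver."""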
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
assert dst_config.trace_config is not None
latent_index = dst_config.trace_config.node_index
assert latent_index is not None
latent_reconstituter = WriteLatentReconstituter(autoencoder_context)
return latent_reconstituter.make_gradient_scalar_source_for_latent_index(
latent_index=latent_index,
dst_config=dst_config,
output_dst=DerivedScalarType.AUTOENCODER_LATENT_GRAD_WRT_RESIDUAL_INPUT,
)


def make_autoencoder_latent_grad_wrt_mlp_post_act_input_scalar_deriver(
dst_config: DstConfig,
node_type: NodeType | None = None,
) -> ScalarDeriver:
"""Output shape (n_tokens, n_neurons)"""
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
assert autoencoder_context.dst == DerivedScalarType.MLP_POST_ACT
return _make_autoencoder_latent_grad_wrt_input_scalar_deriver_helper(
dst_config,
output_dst=DerivedScalarType.AUTOENCODER_LATENT_GRAD_WRT_MLP_POST_ACT_INPUT,
node_type=node_type,
)


def maybe_convert_autoencoder_dst(autoencoder_dst: DerivedScalarType) -> DerivedScalarType:
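    """Map RESID_DELTA_MLP to RESID_DELTA_MLP_FROM_MLP_POST_ACT; pass any other DST through
    unchanged."""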
if autoencoder_dst == DerivedScalarType.RESID_DELTA_MLP:
# TODO: Consider removing this workaround and using RESID_DELTA_MLP directly.
autoencoder_dst = DerivedScalarType.RESID_DELTA_MLP_FROM_MLP_POST_ACT
return autoencoder_dst


def make_autoencoder_write_norm_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
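    """Return a function that builds a ScalarDeriver for the norm of each autoencoder latent's
    write to the residual stream, computed from post-hoc latent activations."""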
def make_autoencoder_write_norm_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
model_context = dst_config.get_model_context()
dst = DerivedScalarType.AUTOENCODER_WRITE_NORM.update_from_autoencoder_node_type(node_type)
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
write_tensor_by_layer_index = get_autoencoder_write_tensor_by_layer_index(
autoencoder_context, model_context
)
scalar_deriver = make_autoencoder_latent_scalar_deriver_factory(node_type)(dst_config)
return convert_scalar_deriver_to_write_norm(
scalar_deriver, write_tensor_by_layer_index, dst
)
return make_autoencoder_write_norm_scalar_deriver


def get_autoencoder_alt_from_node_type(node_type: NodeType | None) -> ActivationLocationType:
"""Get the corresponding activation_location_type from a NodeType."""
return {
NodeType.AUTOENCODER_LATENT: ActivationLocationType.ONLINE_AUTOENCODER_LATENT,
NodeType.MLP_AUTOENCODER_LATENT: ActivationLocationType.ONLINE_MLP_AUTOENCODER_LATENT,
NodeType.ATTENTION_AUTOENCODER_LATENT: ActivationLocationType.ONLINE_ATTENTION_AUTOENCODER_LATENT,
}[node_type or NodeType.AUTOENCODER_LATENT]


def make_online_autoencoder_latent_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
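    """Return a function that builds a ScalarDeriver reading autoencoder latent activations
    recorded online during the forward pass, rather than recomputing them post-hoc."""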
def make_online_autoencoder_latent_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
dst = DerivedScalarType.ONLINE_AUTOENCODER_LATENT.update_from_autoencoder_node_type(
node_type
)
activation_location_type = get_autoencoder_alt_from_node_type(node_type)
return ScalarDeriver(
dst=dst,
dst_config=dst_config,
sub_scalar_sources=(
RawScalarSource(
activation_location_type=activation_location_type,
pass_type=PassType.FORWARD,
),
),
tensor_calculate_derived_scalar_fn=no_op_tensor_calculate_derived_scalar_fn,
)
return make_online_autoencoder_latent_scalar_deriver


def make_online_autoencoder_write_norm_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
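    """Like make_autoencoder_write_norm_scalar_deriver_factory, but based on online latent
    activations."""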
def make_online_autoencoder_write_norm_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
model_context = dst_config.get_model_context()
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
dst = DerivedScalarType.ONLINE_AUTOENCODER_WRITE_NORM.update_from_autoencoder_node_type(
node_type
)
write_tensor_by_layer_index = get_autoencoder_write_tensor_by_layer_index(
autoencoder_context, model_context
)
scalar_deriver = make_online_autoencoder_latent_scalar_deriver_factory(node_type)(
dst_config
)
return convert_scalar_deriver_to_write_norm(
scalar_deriver, write_tensor_by_layer_index, dst
)
return make_online_autoencoder_write_norm_scalar_deriver


def make_online_autoencoder_latentwise_write_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
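    """Return a function that builds a ScalarDeriver for each latent's write vector to the
    residual stream, computed from online latent activations."""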
def make_online_autoencoder_latentwise_write_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
model_context = dst_config.get_model_context()
autoencoder_context = dst_config.get_autoencoder_context(node_type)
assert autoencoder_context is not None
dst = DerivedScalarType.ONLINE_AUTOENCODER_WRITE.update_from_autoencoder_node_type(
node_type
)
write_tensor_by_layer_index = get_autoencoder_write_tensor_by_layer_index(
autoencoder_context, model_context
)
scalar_deriver = make_online_autoencoder_latent_scalar_deriver_factory(node_type)(
dst_config
)
return convert_scalar_deriver_to_write_vector(
scalar_deriver, write_tensor_by_layer_index, dst
)
return make_online_autoencoder_latentwise_write_scalar_deriver


def make_online_autoencoder_write_to_final_residual_grad_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
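    """Return a function that builds a ScalarDeriver for each online latent's write projected
    onto the gradient at the final residual stream location, reusing the existing backward pass."""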
def make_online_autoencoder_write_to_final_residual_grad_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
scalar_deriver = make_online_autoencoder_latent_scalar_deriver_factory(node_type)(
dst_config
)
dst = DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_RESIDUAL_GRAD.update_from_autoencoder_node_type(
node_type
)
return convert_scalar_deriver_to_write_to_final_residual_grad(
scalar_deriver, dst, use_existing_backward_pass_for_final_residual_grad=True
)
return make_online_autoencoder_write_to_final_residual_grad_scalar_deriver


def make_online_autoencoder_write_to_final_activation_residual_grad_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
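    """As above, but projects onto the final-activation residual gradient, computed without
    reusing the existing backward pass."""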
def make_online_autoencoder_write_to_final_activation_residual_grad_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
scalar_deriver = make_online_autoencoder_latent_scalar_deriver_factory(node_type)(
dst_config
)
dst = DerivedScalarType.ONLINE_AUTOENCODER_WRITE_TO_FINAL_ACTIVATION_RESIDUAL_GRAD.update_from_autoencoder_node_type(
node_type
)
return convert_scalar_deriver_to_write_to_final_residual_grad(
scalar_deriver, dst, use_existing_backward_pass_for_final_residual_grad=False
)
return make_online_autoencoder_write_to_final_activation_residual_grad_scalar_deriver


def make_online_autoencoder_act_times_grad_scalar_deriver_factory(
node_type: NodeType | None = None,
) -> Callable[[DstConfig], ScalarDeriver]:
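    """Return a function that builds a ScalarDeriver for the elementwise product of each online
    latent activation with its gradient from the backward pass."""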
def make_online_autoencoder_act_times_grad_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
act_scalar_deriver = make_online_autoencoder_latent_scalar_deriver_factory(node_type)(
dst_config
)
dst = DerivedScalarType.ONLINE_AUTOENCODER_ACT_TIMES_GRAD.update_from_autoencoder_node_type(
node_type
)
        activation_location_type = get_autoencoder_alt_from_node_type(node_type)
return act_scalar_deriver.apply_layerwise_transform_fn_to_output_and_other_tensor(
layerwise_transform_fn=lambda act, grad, layer_index, pass_type: act * grad,
pass_type_to_transform=PassType.FORWARD, # act
other_scalar_source=RawScalarSource(
                activation_location_type=activation_location_type,
pass_type=PassType.BACKWARD, # grad
),
output_dst=dst,
)
return make_online_autoencoder_act_times_grad_scalar_deriver


def make_online_autoencoder_error_scalar_deriver_factory(
activation_location_type: ActivationLocationType,
) -> Callable[[DstConfig], ScalarDeriver]:
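    """Return a function that builds a ScalarDeriver reading the online autoencoder error
    activations recorded at the given activation location type."""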
required_node_type = {
ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR: NodeType.MLP_NEURON,
ActivationLocationType.ONLINE_RESIDUAL_MLP_AUTOENCODER_ERROR: NodeType.RESIDUAL_STREAM_CHANNEL,
ActivationLocationType.ONLINE_RESIDUAL_ATTENTION_AUTOENCODER_ERROR: NodeType.RESIDUAL_STREAM_CHANNEL,
}[activation_location_type]
required_autoencoder_node_type = {
ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR: NodeType.MLP_AUTOENCODER_LATENT,
ActivationLocationType.ONLINE_RESIDUAL_MLP_AUTOENCODER_ERROR: NodeType.MLP_AUTOENCODER_LATENT,
ActivationLocationType.ONLINE_RESIDUAL_ATTENTION_AUTOENCODER_ERROR: NodeType.ATTENTION_AUTOENCODER_LATENT,
}[activation_location_type]
dst = DerivedScalarType.from_activation_location_type(activation_location_type)
def make_online_autoencoder_error_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
autoencoder_context = dst_config.get_autoencoder_context(required_autoencoder_node_type)
assert autoencoder_context is not None
assert autoencoder_context.dst.node_type == required_node_type, (
autoencoder_context.dst,
required_node_type,
activation_location_type,
)
return ScalarDeriver(
dst=dst,
dst_config=dst_config,
sub_scalar_sources=(
RawScalarSource(
activation_location_type=activation_location_type,
pass_type=PassType.FORWARD,
),
),
tensor_calculate_derived_scalar_fn=no_op_tensor_calculate_derived_scalar_fn,
)
return make_online_autoencoder_error_scalar_deriver


def make_online_mlp_autoencoder_error_act_times_grad_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
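    """Build a ScalarDeriver for the elementwise product of the online MLP autoencoder error
    with its gradient from the backward pass."""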
act_scalar_deriver = make_online_autoencoder_error_scalar_deriver_factory(
ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR
)(dst_config)
return act_scalar_deriver.apply_layerwise_transform_fn_to_output_and_other_tensor(
layerwise_transform_fn=lambda act, grad, layer_index, pass_type: act * grad,
pass_type_to_transform=PassType.FORWARD, # act
other_scalar_source=RawScalarSource(
activation_location_type=ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR,
pass_type=PassType.BACKWARD, # grad
),
output_dst=DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_ACT_TIMES_GRAD,
)


def make_online_mlp_autoencoder_error_write_norm_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
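    """Build a ScalarDeriver for the norm of the online MLP autoencoder error's write to the
    residual stream."""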
model_context = dst_config.get_model_context()
autoencoder_context = dst_config.get_autoencoder_context(NodeType.MLP_AUTOENCODER_LATENT)
assert autoencoder_context is not None
write_tensor_by_layer_index = get_mlp_write_tensor_by_layer_index_with_autoencoder_context(
autoencoder_context, model_context
)
scalar_deriver = make_online_autoencoder_error_scalar_deriver_factory(
ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR
)(dst_config)
return convert_scalar_deriver_to_write_norm(
scalar_deriver,
write_tensor_by_layer_index,
DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_WRITE_NORM,
)


def make_online_mlp_autoencoder_error_write_to_final_residual_grad_scalar_deriver(
dst_config: DstConfig,
) -> ScalarDeriver:
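    """Build a ScalarDeriver for the online MLP autoencoder error's write projected onto the
    gradient at the final residual stream location, reusing the existing backward pass."""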
scalar_deriver = make_online_autoencoder_error_scalar_deriver_factory(
ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR
)(dst_config)
return convert_scalar_deriver_to_write_to_final_residual_grad(
scalar_deriver,
DerivedScalarType.ONLINE_MLP_AUTOENCODER_ERROR_WRITE_TO_FINAL_RESIDUAL_GRAD,
use_existing_backward_pass_for_final_residual_grad=True,
)


# helpers for autoencoder gradient
def make_autoencoder_pre_act_encoder_derivative(
autoencoder_context: AutoencoderContext,
layer_index: int,
latent_index: int | None = None,
) -> Callable[[torch.Tensor], torch.Tensor]:
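    """Return a function that applies the Jacobian of the encoder pre-activations with respect
    to the autoencoder input (just the encoder weight matrix, since only linear encoders are
    supported) to its argument, optionally restricted to a single latent."""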
autoencoder = autoencoder_context.get_autoencoder(layer_index)
if isinstance(autoencoder.encoder, torch.nn.Linear):
# if the encoder is linear, then the derivative is just the encoder weight matrix
encoder = autoencoder.encoder.weight # shape (n_latents, n_inputs)
if latent_index is not None:
encoder = encoder[latent_index : latent_index + 1].clone()
# ^ need to clone to avoid MPS backend crash
def pre_act_encoder_derivative(autoencoder_input: torch.Tensor) -> torch.Tensor:
return autoencoder_input @ encoder.T
return pre_act_encoder_derivative
else:
raise NotImplementedError("Only implemented for linear encoder for now")


def make_autoencoder_activation_fn_derivative(
autoencoder_context: AutoencoderContext,
layer_index: int,
) -> Callable[[torch.Tensor], torch.Tensor]:
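    """Return a function computing the elementwise derivative of the autoencoder's activation
    function, evaluated from its post-activations (a 0/1 step function for ReLU)."""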
autoencoder = autoencoder_context.get_autoencoder(layer_index)
if _is_relu(autoencoder.activation):
# if the activation is ReLU, then the derivative is just a step function
def relu_derivative(post_activations: torch.Tensor) -> torch.Tensor:
return (post_activations > 0).to(post_activations.dtype)
return relu_derivative
else:
raise NotImplementedError("Only implemented for ReLU activation function for now")


def _is_relu(activation: Callable) -> bool:
"""More robust than isinstance(activation, torch.nn.ReLU), which doesn't always work."""
if isinstance(activation, torch.nn.ReLU):
return True
else:
test_input = torch.randn(10) * 10 ** torch.randn(10)
return torch.equal(activation(test_input), torch.nn.ReLU()(test_input))