# neuron_explainer/activation_server/derived_scalar_computation.py
"""
This file contains code to add hooks relevant to a list of scalar derivers, and to run forward
passes with those hooks to populate a DerivedScalarStore with the value of the scalars in question.
"""
import gc
import time
from dataclasses import dataclass
from typing import Any, Callable, TypeVar
import torch
from fastapi import HTTPException
from neuron_explainer.activation_server.requests_and_responses import (
GroupId,
InferenceAndTokenData,
InferenceData,
LossFnConfig,
LossFnName,
)
from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.activations.derived_scalars.derived_scalar_store import (
DerivedScalarStore,
RawActivationStore,
)
from neuron_explainer.activations.derived_scalars.direct_effects import (
AttentionDirectEffectReconstituter,
)
from neuron_explainer.activations.derived_scalars.indexing import (
DETACH_LAYER_NORM_SCALE,
AblationSpec,
ActivationIndex,
AttentionTraceType,
TraceConfig,
make_python_slice_from_all_or_one_indices,
)
from neuron_explainer.activations.derived_scalars.logprobs import LogitReconstituter
from neuron_explainer.activations.derived_scalars.multi_group import (
MultiGroupDerivedScalarStore,
MultiGroupScalarDerivers,
)
from neuron_explainer.activations.derived_scalars.reconstituter_class import ActivationReconstituter
from neuron_explainer.activations.derived_scalars.scalar_deriver import (
ActivationLocationTypeAndPassType,
DstConfig,
ScalarDeriver,
)
from neuron_explainer.activations.hook_graph import AutoencoderHookGraph, TransformerHookGraph
from neuron_explainer.models import Transformer
from neuron_explainer.models.autoencoder_context import AutoencoderContext, MultiAutoencoderContext
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
Dimension,
LayerIndex,
NodeType,
PassType,
)
from neuron_explainer.models.model_context import (
InvalidTokenException,
ModelContext,
StandardModelContext,
)
from neuron_explainer.models.transformer import prep_input_and_right_pad_for_forward_pass
T = TypeVar("T")
# a nested dict of lists of tuples, where each tuple contains a DerivedScalarType and a
# DstConfig (together, the information necessary to specify a ScalarDeriver). The nested dict is
# keyed first by spec_name and then by group_id, where spec_name is the name associated with a
# ProcessingRequestSpec, and group_id is a GroupId enum value (each GroupId referring to a
# set of DSTs).
DstAndConfigsByProcessingStep = dict[str, dict[GroupId, list[tuple[DerivedScalarType, DstConfig]]]]
# a nested dict of lists of ScalarDerivers, parallel to and constructed from
# DstAndConfigsByProcessingStep
ScalarDeriversByProcessingStep = dict[str, dict[GroupId, list[ScalarDeriver]]]
# a nested dict of DerivedScalarStores, parallel to and constructed using
# ScalarDeriversByProcessingStep; also uses RawActivationStore as input (though note that
# RawActivationStore is a single object used for the entire nested dict)
DerivedScalarStoreByProcessingStep = dict[str, dict[GroupId, DerivedScalarStore]]
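# Illustrative sketch of the nesting in DstAndConfigsByProcessingStep (placeholder names; the
# specific GroupId and DerivedScalarType members are not spelled out here):
#
#     {
#         "<spec_name>": {                            # name associated with a ProcessingRequestSpec
#             <GroupId member>: [                     # each GroupId refers to a set of DSTs
#                 (<DerivedScalarType member>, DstConfig(...)),
#                 ...
#             ],
#         },
#     }
#
# ScalarDeriversByProcessingStep replaces the innermost list of (DST, config) tuples with a list
# of ScalarDerivers; DerivedScalarStoreByProcessingStep maps each GroupId directly to a single
# DerivedScalarStore.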
@dataclass(frozen=True)
class DerivedScalarComputationParams:
input_token_ints: list[int]
multi_group_scalar_derivers_by_processing_step: dict[str, MultiGroupScalarDerivers]
loss_fn_for_backward_pass: Callable[[torch.Tensor], torch.Tensor] | None
device_for_raw_activations: torch.device
ablation_specs: list[AblationSpec] | None
trace_config: TraceConfig | None
@property
def prompt_length(self) -> int:
return len(self.input_token_ints)
@property
def activation_location_type_and_pass_types(self) -> list[ActivationLocationTypeAndPassType]:
return list(
{
alt_and_pt
for mgsd in self.multi_group_scalar_derivers_by_processing_step.values()
for alt_and_pt in mgsd.activation_location_type_and_pass_types
}
)
def construct_logit_diff_loss_fn(
model_context: ModelContext,
target_tokens: list[str],
distractor_tokens: list[str],
subtract_mean: bool,
) -> Callable[[torch.Tensor], torch.Tensor]:
try:
target_tokens_as_ints = model_context.encode_token_str_list(target_tokens)
distractor_tokens_as_ints = model_context.encode_token_str_list(distractor_tokens)
except InvalidTokenException as e:
raise HTTPException(status_code=400, detail=str(e))
def loss_fn_for_backward_pass(output_logits: torch.Tensor) -> torch.Tensor:
assert output_logits.ndim == 3
nbatch, ntoken, nlogit = output_logits.shape
assert nbatch == 1
assert len(target_tokens_as_ints) > 0
target_mean = output_logits[:, -1, target_tokens_as_ints].mean(-1)
if len(distractor_tokens_as_ints) == 0:
loss = target_mean.mean() # average logits for target tokens
if subtract_mean:
loss -= output_logits[:, -1, :].mean()
return loss
else:
assert (
not subtract_mean
), "subtract_mean not a meaningful option when distractor_tokens is specified"
distractor_mean = output_logits[:, -1, distractor_tokens_as_ints].mean(-1)
return (
target_mean - distractor_mean
).mean() # difference between average logits for target and distractor tokens
return loss_fn_for_backward_pass
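# Illustrative sketch, not called anywhere in this module: applying a loss function produced by
# construct_logit_diff_loss_fn. The token strings and the fake logits tensor are hypothetical
# placeholders (50257 is just an example vocab size), not values used elsewhere in this codebase.
def _example_logit_diff_loss_usage(model_context: ModelContext) -> torch.Tensor:
    loss_fn = construct_logit_diff_loss_fn(
        model_context=model_context,
        target_tokens=[" Paris"],  # hypothetical target token
        distractor_tokens=[" London"],  # hypothetical distractor token
        subtract_mean=False,
    )
    # The returned loss fn expects logits of shape (1, n_tokens, n_vocab) and reads only the
    # final token position.
    fake_logits = torch.randn(1, 5, 50257, requires_grad=True)
    return loss_fn(fake_logits)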
def construct_probs_loss_fn(
model_context: ModelContext, target_tokens: list[str]
) -> Callable[[torch.Tensor], torch.Tensor]:
try:
target_tokens_as_ints = model_context.encode_token_str_list(target_tokens)
except InvalidTokenException as e:
raise HTTPException(status_code=400, detail=str(e))
def loss_fn_for_backward_pass(output_logits: torch.Tensor) -> torch.Tensor:
assert output_logits.ndim == 3
output_probs = torch.softmax(output_logits, dim=-1)
nbatch, ntoken, nlogit = output_probs.shape
assert nbatch == 1
assert len(target_tokens_as_ints) > 0
target_sum = output_probs[:, -1, target_tokens_as_ints].sum(-1)
return target_sum.mean() # average summed probs for target tokens
return loss_fn_for_backward_pass
def construct_zero_loss_fn() -> Callable[[torch.Tensor], torch.Tensor]:
"""This loss function is used for running a backward pass that will be interrupted
by ablating some desired parameters. Parameters downstream of the ablated parameters
will have a gradient of 0, and parameters upstream of the ablated parameters will
in general have a non-zero gradient."""
def loss_fn_for_backward_pass(output_logits: torch.Tensor) -> torch.Tensor:
return 0.0 * output_logits.sum()
return loss_fn_for_backward_pass
def maybe_construct_loss_fn_for_backward_pass(
model_context: ModelContext, config: LossFnConfig | None
) -> Callable[[torch.Tensor], torch.Tensor] | None:
if config is None:
return None
else:
if config.name == LossFnName.LOGIT_DIFF:
assert config.target_tokens is not None
target_tokens = config.target_tokens
distractor_tokens = config.distractor_tokens or []
return construct_logit_diff_loss_fn(
model_context=model_context,
target_tokens=target_tokens,
distractor_tokens=distractor_tokens,
subtract_mean=False,
)
elif config.name == LossFnName.LOGIT_MINUS_MEAN:
assert config.target_tokens is not None
assert config.distractor_tokens is None
return construct_logit_diff_loss_fn(
model_context=model_context,
target_tokens=config.target_tokens,
distractor_tokens=[],
subtract_mean=True,
)
elif config.name == LossFnName.PROBS:
assert config.target_tokens is not None
assert config.distractor_tokens is None
target_tokens = config.target_tokens
return construct_probs_loss_fn(model_context=model_context, target_tokens=target_tokens)
elif config.name == LossFnName.ZERO:
return construct_zero_loss_fn()
else:
raise NotImplementedError(f"Unknown loss fn name: {config.name}")
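# Illustrative sketch, not called anywhere in this module: building a loss function from a
# LossFnConfig. This assumes LossFnConfig can be constructed with the name / target_tokens /
# distractor_tokens fields read above; the token strings are hypothetical.
def _example_loss_fn_from_config(
    model_context: ModelContext,
) -> Callable[[torch.Tensor], torch.Tensor] | None:
    config = LossFnConfig(
        name=LossFnName.LOGIT_DIFF,
        target_tokens=[" yes"],  # hypothetical
        distractor_tokens=[" no"],  # hypothetical
    )
    return maybe_construct_loss_fn_for_backward_pass(model_context=model_context, config=config)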
ablatable_activation_location_type_by_node_type = {
NodeType.MLP_NEURON: ActivationLocationType.MLP_POST_ACT,
NodeType.ATTENTION_HEAD: ActivationLocationType.ATTN_QK_PROBS,
NodeType.RESIDUAL_STREAM_CHANNEL: ActivationLocationType.RESID_POST_MLP,
NodeType.AUTOENCODER_LATENT: ActivationLocationType.ONLINE_AUTOENCODER_LATENT,
}
def compute_derived_scalar_groups_for_input_token_ints(
model_context: StandardModelContext,
multi_autoencoder_context: MultiAutoencoderContext | None,
batched_ds_computation_params: list[DerivedScalarComputationParams],
) -> tuple[
list[dict[str, MultiGroupDerivedScalarStore]], list[InferenceData], list[RawActivationStore]
]:
"""This function runs a batched forward pass on the given batch of input token sequences, with
hooks added to the transformer to extract the activations needed to compute the scalars in
multi_group_scalar_derivers for each batch element. It then returns a batch of dicts of
DerivedScalarStores by group_id containing the relevant derived scalars for each token in
the input, as well as a batch of InferenceData objects containing tokenized inputs and other metadata,
and a batch of RawActivationStores, each of which was used to compute the respective
dict of DerivedScalarStores. These RawActivationStores can be used to compute additional derived scalars
post-hoc.
"""
(
batched_raw_activation_store,
batched_loss,
batched_activation_value_for_backward_pass,
batched_memory_used_before,
inference_time,
) = run_inference_and_populate_raw_store(
model_context=model_context,
multi_autoencoder_context=multi_autoencoder_context,
batched_ds_computation_params=batched_ds_computation_params,
)
assert (
len(batched_raw_activation_store)
== len(batched_ds_computation_params)
== len(batched_loss)
== len(batched_activation_value_for_backward_pass)
== len(batched_memory_used_before)
)
batched_multi_group_ds_store_by_processing_step: list[
dict[str, MultiGroupDerivedScalarStore]
] = []
(
batched_multi_group_ds_store_by_processing_step,
batched_memory_used_after,
) = construct_ds_stores_from_raw(
batched_raw_activation_store,
batched_ds_computation_params,
)
batched_inference_data: list[InferenceData] = []
for (
loss,
activation_value_for_backward_pass,
memory_used_before,
memory_used_after,
) in zip(
batched_loss,
batched_activation_value_for_backward_pass,
batched_memory_used_before,
batched_memory_used_after,
):
inference_data = InferenceData(
inference_time=inference_time,
loss=loss,
activation_value_for_backward_pass=activation_value_for_backward_pass,
memory_used_before=memory_used_before,
memory_used_after=memory_used_after,
)
batched_inference_data.append(inference_data)
return (
batched_multi_group_ds_store_by_processing_step,
batched_inference_data,
batched_raw_activation_store,
)
def get_activation_index_and_reconstitute_activation_fn(
transformer: Transformer,
multi_autoencoder_context: MultiAutoencoderContext | None,
trace_config: TraceConfig,
) -> tuple[ActivationIndex, Callable[[torch.Tensor, torch.Tensor | None], torch.Tensor]]:
"""
This function returns the ActivationIndex corresponding to the preceding residual stream
index implied by the trace_config. It also returns a function taking one tensor, used to recompute
the activation specified by the trace_config from the residual stream.
"""
assert trace_config.attention_trace_type != AttentionTraceType.V
if trace_config.node_type.is_autoencoder_latent:
assert multi_autoencoder_context is not None
autoencoder_context = multi_autoencoder_context.get_autoencoder_context(
trace_config.node_type
)
assert autoencoder_context is not None
else:
autoencoder_context = None
act_reconstituter = ActivationReconstituter.from_trace_config(
transformer=transformer,
autoencoder_context=autoencoder_context,
trace_config=trace_config,
)
activation_index_for_reconstituter = act_reconstituter.get_residual_activation_index_for_node_index(
# convert trace_config to node index
trace_config.node_index
)
def reconstitute_activation_fn(
upstream_resid: torch.Tensor, _unused_downstream_resid: torch.Tensor | None
) -> torch.Tensor:
assert _unused_downstream_resid is None
reconstitute_activation = (
act_reconstituter.make_reconstitute_activation_fn_for_trace_config(
trace_config=trace_config
)
)
return reconstitute_activation(upstream_resid)
return activation_index_for_reconstituter, reconstitute_activation_fn
def get_activation_indices_and_reconstitute_direct_effect_fn(
model_context: ModelContext,
multi_autoencoder_context: MultiAutoencoderContext | None,
trace_config: TraceConfig,
loss_fn_for_backward_pass: Callable[[torch.Tensor], torch.Tensor] | None,
) -> tuple[
ActivationIndex,
ActivationIndex,
Callable[[torch.Tensor, torch.Tensor | None], torch.Tensor],
]:
"""
For use with AttentionTraceType.V. This function returns the ActivationIndex corresponding to the
residual stream before the upstream (attention V) node as well as the ActivationIndex corresponding to the
residual stream before the downstream node, or before the loss. It also returns a function taking two arguments,
both residual stream activations, which computes the direct effect of the upstream activation on the downstream
activation.
If trace_config.attention_trace_config.downstream_trace_config is a normal TraceConfig, then it specifies a downstream
node's activation, which will be used to compute a gradient. If it is None, it is assumed that the output logits (in
their entirety) are being reconstructed instead, and a loss function computed, to compute the gradient.
"""
assert trace_config.attention_trace_type == AttentionTraceType.V
if trace_config.downstream_trace_config is None:
assert loss_fn_for_backward_pass is not None
assert isinstance(model_context, StandardModelContext) # for typechecking
logit_reconstituter = LogitReconstituter(
model_context=model_context,
detach_layer_norm_scale=DETACH_LAYER_NORM_SCALE,
)
downstream_activation_index_for_reconstituter = (
logit_reconstituter.get_residual_activation_index()
)
def reconstitute_gradient_fn(downstream_resid: torch.Tensor) -> torch.Tensor:
reconstitute_gradient = logit_reconstituter.make_reconstitute_gradient_of_loss_fn(
loss_fn=loss_fn_for_backward_pass
)
return reconstitute_gradient(downstream_resid)
else:
if trace_config.node_type.is_autoencoder_latent:
assert multi_autoencoder_context is not None
autoencoder_context = multi_autoencoder_context.get_autoencoder_context(
trace_config.node_type
)
assert autoencoder_context is not None
else:
autoencoder_context = None
act_reconstituter = ActivationReconstituter.from_trace_config(
transformer=model_context.get_or_create_model(),
autoencoder_context=autoencoder_context,
trace_config=trace_config.downstream_trace_config,
)
downstream_trace_config = trace_config.downstream_trace_config
downstream_activation_index_for_reconstituter = act_reconstituter.get_residual_activation_index_for_node_index(
# convert trace_config to node index
downstream_trace_config.node_index
)
def reconstitute_gradient_fn(
downstream_resid: torch.Tensor,
) -> torch.Tensor:
reconstitute_gradient_with_args = (
act_reconstituter.make_reconstitute_gradient_fn_for_trace_config(
trace_config=downstream_trace_config
)
)
return reconstitute_gradient_with_args(
downstream_resid,
downstream_trace_config.layer_index,
downstream_trace_config.pass_type,
)
assert trace_config.layer_index is not None
direct_effect_reconstituter = AttentionDirectEffectReconstituter(
model_context=model_context,
layer_indices=[trace_config.layer_index],
detach_layer_norm_scale=DETACH_LAYER_NORM_SCALE,
)
upstream_activation_index_for_reconstituter = direct_effect_reconstituter.get_residual_activation_index_for_node_index(
# convert trace_config to node index
trace_config.node_index
)
upstream_scalar_hook = direct_effect_reconstituter.make_scalar_hook_for_node_index(
trace_config.node_index
)
def reconstitute_direct_effect_fn(
upstream_resid: torch.Tensor,
downstream_resid: torch.Tensor | None,
) -> torch.Tensor:
assert downstream_resid is not None
gradient = reconstitute_gradient_fn(downstream_resid).detach()
activations = direct_effect_reconstituter.reconstitute_activations(
resid=upstream_resid,
grad=gradient,
layer_index=trace_config.layer_index,
pass_type=trace_config.pass_type,
)
return upstream_scalar_hook(activations)
return (
upstream_activation_index_for_reconstituter,
downstream_activation_index_for_reconstituter,
reconstitute_direct_effect_fn,
)
def replace_activation_index_using_reconstituter(
model_context: ModelContext,
multi_autoencoder_context: MultiAutoencoderContext | None,
batched_ds_computation_params: list[DerivedScalarComputationParams],
) -> tuple[
list[ActivationIndex | None],
list[ActivationIndex | None],
list[Callable[[torch.Tensor, torch.Tensor | None], torch.Tensor]],
]:
"""
Where trace_config occurs in batched_ds_computation_params, convert it to the
upstream_activation_index (and optionally also downstream_activation_index)
corresponding to the preceding residual stream required by a Reconstituter. Also
return a function, generated from the Reconstituter, to obtain the activation
corresponding to the original trace_config from the residual stream.
"""
batched_reconstitute_activation_fn: list[
Callable[[torch.Tensor, torch.Tensor | None], torch.Tensor]
] = []
batched_upstream_activation_index_to_grab: list[ActivationIndex | None] = []
batched_downstream_activation_index_to_grab: list[ActivationIndex | None] = []
for ds_computation_params_index in range(len(batched_ds_computation_params)):
ds_computation_params = batched_ds_computation_params[ds_computation_params_index]
trace_config = ds_computation_params.trace_config
if trace_config is not None:
if trace_config.attention_trace_type == AttentionTraceType.V:
(
upstream_activation_index_for_reconstituter,
downstream_activation_index_for_reconstituter,
reconstitute_activation_fn,
) = get_activation_indices_and_reconstitute_direct_effect_fn(
model_context=model_context,
multi_autoencoder_context=multi_autoencoder_context,
trace_config=trace_config,
loss_fn_for_backward_pass=ds_computation_params.loss_fn_for_backward_pass,
)
else:
(
upstream_activation_index_for_reconstituter,
reconstitute_activation_fn,
) = get_activation_index_and_reconstitute_activation_fn(
transformer=model_context.get_or_create_model(),
multi_autoencoder_context=multi_autoencoder_context,
trace_config=trace_config,
)
downstream_activation_index_for_reconstituter = None
else:
upstream_activation_index_for_reconstituter = None
downstream_activation_index_for_reconstituter = None
def dummy_reconstitute_activation_fn(
resid: torch.Tensor, grad: torch.Tensor | None
) -> torch.Tensor:
raise NotImplementedError("This function should not be called")
reconstitute_activation_fn = dummy_reconstitute_activation_fn
if upstream_activation_index_for_reconstituter is not None:
assert (
upstream_activation_index_for_reconstituter.activation_location_type.node_type
== NodeType.RESIDUAL_STREAM_CHANNEL
)
if downstream_activation_index_for_reconstituter is not None:
assert (
downstream_activation_index_for_reconstituter.activation_location_type.node_type
== NodeType.RESIDUAL_STREAM_CHANNEL
)
batched_upstream_activation_index_to_grab.append(
upstream_activation_index_for_reconstituter
)
batched_downstream_activation_index_to_grab.append(
downstream_activation_index_for_reconstituter
)
batched_reconstitute_activation_fn.append(reconstitute_activation_fn)
return (
batched_upstream_activation_index_to_grab,
batched_downstream_activation_index_to_grab,
batched_reconstitute_activation_fn,
)
def run_inference_and_populate_raw_store(
model_context: StandardModelContext,
multi_autoencoder_context: MultiAutoencoderContext | None,
batched_ds_computation_params: list[DerivedScalarComputationParams],
) -> tuple[
list[RawActivationStore],
list[float | None],
list[float | None],
list[float | None],
float,
]:
"""
This populates a dict of ActivationsAndMetadata objects for each batch element, and returns
inference-related stats.
- batched_requested_activations_by_location_type_and_pass_type: stored activations
- batched_loss_floats: loss values for each batch element, if a loss function was specified
- batched_activation_value_for_backward_pass_floats: value of activation for which a backward pass was
computed for each batch element, if an activation index for backward pass was specified
- batched_memory_used_before: amount of memory allocated to the GPU before the forward pass for each batch element
- inference_time: time used for the forward pass, in seconds
"""
for params in batched_ds_computation_params:
trace_config = params.trace_config
if trace_config is not None:
if trace_config.node_type.is_autoencoder_latent:
assert multi_autoencoder_context is not None
assert (
multi_autoencoder_context.get_autoencoder_context(trace_config.node_type)
is not None
), f"Autoencoder context not found for {trace_config.node_type}"
transformer = model_context.get_or_create_model()
batched_input_token_ints = [params.input_token_ints for params in batched_ds_computation_params]
tokens_tensor, pad_tensor = prep_input_and_right_pad_for_forward_pass(
batched_input_token_ints, transformer.device
)
(
batched_upstream_activation_index_for_backward_pass,
batched_downstream_activation_index_for_backward_pass,
batched_reconstitute_activation_fn,
) = replace_activation_index_using_reconstituter(
model_context=model_context,
multi_autoencoder_context=multi_autoencoder_context,
batched_ds_computation_params=batched_ds_computation_params,
)
batched_activation_index_for_backward_pass_by_name = [
{
"upstream": upstream_activation_index_for_backward_pass,
"downstream": downstream_activation_index_for_backward_pass,
}
for (
upstream_activation_index_for_backward_pass,
downstream_activation_index_for_backward_pass,
) in zip(
batched_upstream_activation_index_for_backward_pass,
batched_downstream_activation_index_for_backward_pass,
)
]
(
transformer_graph, # Stores the hooks
batched_requested_activations_by_location_type_and_pass_type, # Stores the activations from the hooks after the forward pass
batched_requested_attached_activations_for_backward_pass_by_name,
) = get_transformer_graph_hooks_and_activation_caches(
multi_autoencoder_context=multi_autoencoder_context,
batched_ds_computation_params=batched_ds_computation_params,
batched_activation_index_for_backward_pass_by_name=batched_activation_index_for_backward_pass_by_name,
)
assert len(batched_requested_activations_by_location_type_and_pass_type) == len(
batched_ds_computation_params
)
for (
requested_activations_by_location_type_and_pass_type,
ds_computation_params,
) in zip(
batched_requested_activations_by_location_type_and_pass_type,
batched_ds_computation_params,
):
pass_types = [
activation_location_type_and_pass_type.pass_type
for activation_location_type_and_pass_type in requested_activations_by_location_type_and_pass_type.keys()
]
if any(pass_type == PassType.BACKWARD for pass_type in pass_types):
assert (
ds_computation_params.loss_fn_for_backward_pass is not None
or ds_computation_params.trace_config is not None
), "loss_fn_for_backward_pass or trace_config must be defined if gradients are required"
batched_device_for_raw_activations = [
params.device_for_raw_activations for params in batched_ds_computation_params
]
t0 = time.time()
cuda_available = torch.cuda.is_available()
if cuda_available and any(
device.type == "cuda" for device in batched_device_for_raw_activations
):
torch.cuda.empty_cache()
batched_memory_used_before: list[float | None] = [
torch.cuda.memory_allocated(device) if device.type == "cuda" and cuda_available else None
for device in batched_device_for_raw_activations
]
logits, _ = transformer.forward(
tokens_tensor, pad=pad_tensor, hooks=transformer_graph.as_transformer_hooks()
)
batched_loss: list[torch.Tensor | None] = []
batched_activation_value_for_backward_pass: list[torch.Tensor | None] = []
for batch_index, (
ds_computation_params,
requested_attached_activation_for_backward_pass_by_name,
reconstitute_activation_fn,
activation_index_for_backward_pass_by_name,
) in enumerate(
zip(
batched_ds_computation_params,
batched_requested_attached_activations_for_backward_pass_by_name,
batched_reconstitute_activation_fn,
batched_activation_index_for_backward_pass_by_name,
)
):
loss_fn_for_backward_pass = ds_computation_params.loss_fn_for_backward_pass
if loss_fn_for_backward_pass is not None:
loss = loss_fn_for_backward_pass(logits[batch_index].unsqueeze(0))
else:
loss = None
if activation_index_for_backward_pass_by_name["upstream"] is not None:
assert requested_attached_activation_for_backward_pass_by_name["upstream"] is not None
activation_value_for_backward_pass = reconstitute_activation_fn(
requested_attached_activation_for_backward_pass_by_name["upstream"],
requested_attached_activation_for_backward_pass_by_name["downstream"],
)
else:
activation_value_for_backward_pass = None
batched_loss.append(loss)
batched_activation_value_for_backward_pass.append(activation_value_for_backward_pass)
populated_losses: list[torch.Tensor] = []
for loss, value in zip(batched_loss, batched_activation_value_for_backward_pass):
# backward pass is computed from value if it is not None, otherwise from loss
if value is not None:
populated_losses.append(value)
elif loss is not None:
populated_losses.append(loss)
if len(populated_losses):
assert all(isinstance(loss, torch.Tensor) for loss in populated_losses)
loss_sum = sum(populated_losses)
assert isinstance(loss_sum, torch.Tensor)
loss_sum.backward()
inference_time = time.time() - t0
batched_loss_floats: list[float | None] = [
loss.item() if loss is not None and not torch.isnan(loss) else None for loss in batched_loss
]
batched_activation_value_for_backward_pass_floats: list[float | None] = [
activation.item() if activation is not None else None
for activation in batched_activation_value_for_backward_pass
]
batched_raw_activation_store: list[RawActivationStore] = []
    for (
        requested_activations_by_location_type_and_pass_type
    ) in batched_requested_activations_by_location_type_and_pass_type:
raw_activation_store = RawActivationStore.from_nested_dict_of_activations(
requested_activations_by_location_type_and_pass_type
)
batched_raw_activation_store.append(raw_activation_store)
assert (
len(batched_raw_activation_store)
== len(batched_loss_floats)
== len(batched_activation_value_for_backward_pass_floats)
== len(batched_memory_used_before)
== len(batched_ds_computation_params)
) # returns one batch element per input param setting
return (
batched_raw_activation_store,
batched_loss_floats,
batched_activation_value_for_backward_pass_floats,
batched_memory_used_before,
inference_time,
)
def construct_ds_stores_from_raw(
batched_raw_activation_store: list[RawActivationStore],
batched_ds_computation_params: list[DerivedScalarComputationParams],
) -> tuple[list[dict[str, MultiGroupDerivedScalarStore]], list[float | None]]:
batched_multi_group_ds_store_by_processing_step: list[
dict[str, MultiGroupDerivedScalarStore]
] = []
batched_memory_used_after: list[float | None] = []
for (
raw_activation_store,
ds_computation_params,
) in zip(
batched_raw_activation_store,
batched_ds_computation_params,
):
multi_group_scalar_derivers_by_processing_step = (
ds_computation_params.multi_group_scalar_derivers_by_processing_step
)
multi_group_ds_store_by_processing_step: dict[str, MultiGroupDerivedScalarStore] = {}
for (
spec_name,
multi_group_scalar_derivers,
) in multi_group_scalar_derivers_by_processing_step.items():
multi_group_ds_store_by_processing_step[
spec_name
] = MultiGroupDerivedScalarStore.derive_from_raw(
raw_activation_store, multi_group_scalar_derivers
)
batched_multi_group_ds_store_by_processing_step.append(
multi_group_ds_store_by_processing_step
)
device_for_raw_activations = ds_computation_params.device_for_raw_activations
memory_used_after = None
if torch.cuda.is_available() and device_for_raw_activations.type == "cuda":
gc.collect()
memory_used_after = torch.cuda.memory_allocated(device_for_raw_activations)
batched_memory_used_after.append(memory_used_after)
return (
batched_multi_group_ds_store_by_processing_step,
batched_memory_used_after,
)
def get_transformer_graph_hooks_and_activation_caches(
multi_autoencoder_context: MultiAutoencoderContext | None,
batched_ds_computation_params: list[DerivedScalarComputationParams],
batched_activation_index_for_backward_pass_by_name: list[dict[str, ActivationIndex | None]],
) -> tuple[
TransformerHookGraph,
list[dict[ActivationLocationTypeAndPassType, dict[LayerIndex, torch.Tensor]]],
list[dict[str, torch.Tensor | None]],
]:
"""This is a helper function that returns:
1. a TransformerHookGraph object containing hooks for the given
scalar derivers (this can be passed to Transformer using the as_transformer_hooks method to add
hooks to the transformer forward and backward pass)
2. dictionaries mapping each activation location type to the activations requested
(before it is filled with activations during forward passes), one dictionary per batch element
3. dictionaries each containing just one value, the scalar tensor on which the backward pass can be run (
unlike the activations in 2, this tensor is still attached to the pytorch model), one dictionary per batch element
"""
batched_activation_location_type_and_pass_types = [
params.activation_location_type_and_pass_types for params in batched_ds_computation_params
]
batched_device = [params.device_for_raw_activations for params in batched_ds_computation_params]
batched_ablation_specs = [params.ablation_specs for params in batched_ds_computation_params]
batched_prompt_lengths = [params.prompt_length for params in batched_ds_computation_params]
# This is a callable that can be used similarly to a Hooks object.
transformer_graph = TransformerHookGraph()
"""This step constructs the activation location types needed (for the case where they don't already exist). Any hooks to be added
to that location type can then be appended to transformer_graph in the normal way.
"""
# steps:
# 1. add autoencoder graph to transformer_graph (injected autoencoders specified in multi_autoencoder_context are
# hooked in a second set of forward hooks, called after ablating and saving hooks, and before activation grabbing hooks)
#
    # for each batch element:
# 2. add ablating hooks
# 3. add activation grabbing hooks (storing without detaching, for backward pass). These come last, after ablating and saving hooks.
# 4. add saving hooks (storing with detaching)
#
# Because the autoencoder is hooked in a second set of hooks, followed by the grabbing hooks, they are always called last:
# - fwd hooks: ablate activations, save activations, ...
# - bwd hooks: ablate gradients, save gradients, ...
# - fwd2 hooks: autoencoder (convert to latent, ablate latents, grab latents, save latents, convert back to activations), grab activations, ...
# add autoencoder graph
if multi_autoencoder_context is not None:
has_multiple_autoencoders = (
len(multi_autoencoder_context.autoencoder_context_by_node_type) > 1
)
for (
node_type,
autoencoder_context,
) in multi_autoencoder_context.autoencoder_context_by_node_type.items():
subgraph = AutoencoderHookGraph(
autoencoder_context, is_one_of_multiple_autoencoders=has_multiple_autoencoders
)
subgraph_name = f"{node_type.value}"
transformer_graph.inject_subgraph(subgraph, subgraph_name)
(
batched_requested_activations_by_location_type_and_pass_type,
batched_requested_attached_activation_for_backward_pass_by_name,
) = ([], [])
assert (
len(batched_activation_location_type_and_pass_types)
== len(batched_ablation_specs)
== len(batched_activation_index_for_backward_pass_by_name)
== len(batched_prompt_lengths)
)
for i in range(len(batched_activation_location_type_and_pass_types)):
# add ablating hooks
ablation_spec = batched_ablation_specs[i]
if ablation_spec is not None:
add_ablating_hooks(transformer_graph, ablation_spec, batch_index=i)
# add activation grabbing hooks; the grabbed activations are stored in dicts, keyed by the
# string name assigned to them
requested_attached_activation_for_backward_pass_by_name: dict[str, torch.Tensor | None] = {}
for (
name,
activation_index_for_backward_pass,
) in batched_activation_index_for_backward_pass_by_name[i].items():
if activation_index_for_backward_pass is not None:
requested_attached_activation_for_backward_pass_by_name = add_grabbing_hook_for_backward_pass(
requested_attached_activation_for_backward_pass_by_name,
name,
transformer_graph=transformer_graph,
activation_index_for_backward_pass=activation_index_for_backward_pass,
batch_index=i,
append_to_fwd2=True, # append to fwd2 when using Reconstituter to obtain gradients, so that
# the preceding residual stream backward pass hooks can be called after running .backward()
# from the grabbed activation
)
else:
requested_attached_activation_for_backward_pass_by_name[name] = None
# add saving hooks
requested_activations_by_location_type_and_pass_type = add_saving_hooks(
transformer_graph=transformer_graph,
activation_location_type_and_pass_types=batched_activation_location_type_and_pass_types[
i
],
device=batched_device[i],
unpadded_prompt_length=batched_prompt_lengths[i],
batch_index=i,
)
batched_requested_activations_by_location_type_and_pass_type.append(
requested_activations_by_location_type_and_pass_type
)
batched_requested_attached_activation_for_backward_pass_by_name.append(
requested_attached_activation_for_backward_pass_by_name
)
return (
transformer_graph,
batched_requested_activations_by_location_type_and_pass_type,
batched_requested_attached_activation_for_backward_pass_by_name,
)
def create_activation_grabbing_hook_fn(
attached_activation_dict: dict[str, torch.Tensor | None],
name: str,
activation_index: ActivationIndex,
batch_index: int,
) -> tuple[Callable, ActivationLocationTypeAndPassType, LayerIndex, dict[str, torch.Tensor | None]]:
assert (
name not in attached_activation_dict
), f"Name {name} already exists in attached_activation_dict"
def activation_grabbing_hook_fn(act: torch.Tensor, **kwargs: Any) -> torch.Tensor:
indices: tuple[slice | int, ...] = (
batch_index,
) + make_python_slice_from_all_or_one_indices(
activation_index.tensor_indices
) # use the batch_index to index into the batch dimension
attached_activation_dict[name] = act[indices] # .clone()
return act
return (
activation_grabbing_hook_fn,
ActivationLocationTypeAndPassType(
activation_location_type=activation_index.activation_location_type,
pass_type=activation_index.pass_type,
),
activation_index.layer_index,
attached_activation_dict,
)
def add_grabbing_hook_for_backward_pass(
attached_activation_dict: dict[str, torch.Tensor | None],
name: str,
transformer_graph: TransformerHookGraph,
activation_index_for_backward_pass: ActivationIndex,
batch_index: int,
append_to_fwd2: bool = False,
) -> dict[str, torch.Tensor | None]:
"""This is a helper function that returns a TransformerHooks object containing hooks at naturally existing hook locations
(e.g. MLP post-activations, rather than autoencoder latent activations) for the given
scalar derivers, and a dictionary mapping each activation location type with a naturally existing hook to the activations requested
(before it is filled with activations during forward passes)."""
(
hook_fn,
activation_location_type_and_pass_type,
layer_index,
attached_activation_dict,
) = create_activation_grabbing_hook_fn(
attached_activation_dict, name, activation_index_for_backward_pass, batch_index=batch_index
)
transformer_graph.append(
activation_location_type_and_pass_type,
hook_fn,
layer_indices=layer_index,
append_to_fwd2=append_to_fwd2,
)
return attached_activation_dict
def create_ablating_hook_fn(
ablation_spec: AblationSpec,
batch_index: int,
) -> tuple[Callable, ActivationLocationTypeAndPassType, LayerIndex]:
activation_location_type_and_pass_type = ActivationLocationTypeAndPassType(
activation_location_type=ablation_spec.index.activation_location_type,
pass_type=ablation_spec.index.pass_type,
)
layer_index = ablation_spec.index.layer_index
def ablating_hook_fn(act: torch.Tensor, **kwargs: Any) -> torch.Tensor:
# Initial dimension is batch
act = act.clone()
python_slice_operator: tuple[slice | int, ...] = (
slice(batch_index, batch_index + 1),
) + make_python_slice_from_all_or_one_indices(ablation_spec.index.tensor_indices)
assert len(python_slice_operator) == len(act.shape), (
len(python_slice_operator),
python_slice_operator,
len(act.shape),
act.shape,
)
act[python_slice_operator] = ablation_spec.value
return act
return ablating_hook_fn, activation_location_type_and_pass_type, layer_index
def add_ablating_hooks(
transformer_graph: TransformerHookGraph,
ablation_specs: list[AblationSpec],
batch_index: int,
) -> None:
"""This is a helper function that returns a TransformerHooks object containing hooks at naturally existing hook locations
(e.g. MLP post-activations, rather than autoencoder latent activations) for the given
scalar derivers, and a dictionary mapping each activation location type with a naturally existing hook to the activations requested
(before it is filled with activations during forward passes)."""
for ablation_spec in ablation_specs:
(
hook_fn,
activation_location_type_and_pass_type,
layer_index,
) = create_ablating_hook_fn(ablation_spec, batch_index=batch_index)
transformer_graph.append(
activation_location_type_and_pass_type,
hook_fn,
layer_indices=layer_index,
)
def create_saving_hook_fn(
device: torch.device,
activation_location_type_and_pass_type: ActivationLocationTypeAndPassType,
unpadded_prompt_length: int,
batch_index: int,
) -> tuple[Callable, dict[LayerIndex, torch.Tensor]]:
requested_activations_by_layer_index = {}
def saving_hook_fn(act: torch.Tensor, **kwargs: Any) -> torch.Tensor:
layer_index = kwargs.get("layer", None)
# First dimension is batch, second dimension is sequence length. We truncate the sequence
# length to the unpadded prompt length. If the third dimension is also sequence length, we
# truncate that too.
shape_spec = (
activation_location_type_and_pass_type.activation_location_type.shape_spec_per_token_sequence
)
def get_slice_for_dim(dim: Dimension) -> slice:
if dim.is_sequence_token_dimension:
return slice(None, unpadded_prompt_length)
else:
return slice(None)
truncated_act = act[(batch_index,) + tuple([get_slice_for_dim(dim) for dim in shape_spec])]
requested_activations_by_layer_index[layer_index] = truncated_act.detach().to(device)
return act
return saving_hook_fn, requested_activations_by_layer_index
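# Illustrative example of the truncation above: for an activation whose per-token-sequence shape
# is (token-sequence dimension, hidden dimension), a padded prompt of 16 tokens, and an unpadded
# prompt of 11 tokens, the hook stores act[batch_index, :11, :] (detached, moved to `device`).
# If a second dimension is also a token-sequence dimension (e.g. attention probabilities), it is
# truncated to 11 as well.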
def add_saving_hooks(
transformer_graph: TransformerHookGraph,
activation_location_type_and_pass_types: list[ActivationLocationTypeAndPassType],
device: torch.device,
unpadded_prompt_length: int,
batch_index: int,
) -> dict[ActivationLocationTypeAndPassType, dict[LayerIndex, torch.Tensor]]:
"""This is a helper function that returns a TransformerHooks object containing hooks at naturally existing hook locations
(e.g. MLP post-activations, rather than autoencoder latent activations) for the given
scalar derivers, and a dictionary mapping each activation location type with a naturally existing hook to the activations requested
(before it is filled with activations during forward passes)."""
requested_activations_by_location_type_and_pass_type = {}
for activation_location_type_and_pass_type in activation_location_type_and_pass_types:
(
hook_fn,
requested_activations_by_location_type_and_pass_type[
activation_location_type_and_pass_type
],
) = create_saving_hook_fn(
device,
activation_location_type_and_pass_type=activation_location_type_and_pass_type,
unpadded_prompt_length=unpadded_prompt_length,
batch_index=batch_index,
)
transformer_graph.append(
activation_location_type_and_pass_type,
hook_fn,
)
return requested_activations_by_location_type_and_pass_type
def apply_default_dst_configs_to_dst_and_config_list(
model_context: StandardModelContext,
multi_autoencoder_context: MultiAutoencoderContext | None,
dst_and_config_list: list[tuple[DerivedScalarType, DstConfig | None]],
) -> list[tuple[DerivedScalarType, DstConfig]]:
def get_default_dst_config(
dst: DerivedScalarType,
) -> DstConfig:
return DstConfig(
model_context=model_context,
multi_autoencoder_context=multi_autoencoder_context,
derive_gradients=not dst.requires_grad_for_forward_pass,
)
return [
(dst, config if config is not None else get_default_dst_config(dst))
for dst, config in dst_and_config_list
]
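# Illustrative example: given [(dst_a, None), (dst_b, custom_config)] (hypothetical names), the
# helper above returns [(dst_a, <default DstConfig built from the model and autoencoder
# contexts>), (dst_b, custom_config)]; explicit configs are passed through unchanged.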
def get_ds_computation_params_for_prompt(
model_context: StandardModelContext,
autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None,
dst_and_config_list: list[
tuple[DerivedScalarType, DstConfig | None]
], # None -> default config for dst
prompt: str,
loss_fn_for_backward_pass: Callable[[torch.Tensor], torch.Tensor] | None,
trace_config: TraceConfig | None,
ablation_specs: list[AblationSpec],
) -> DerivedScalarComputationParams:
assert (loss_fn_for_backward_pass is None) or (trace_config is None)
multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
autoencoder_context
)
dst_and_config_list_with_default_config = apply_default_dst_configs_to_dst_and_config_list(
model_context, multi_autoencoder_context, dst_and_config_list
)
multi_group_scalar_derivers_by_processing_step = {
"dummy": MultiGroupScalarDerivers.from_dst_and_config_list(
dst_and_config_list_with_default_config
) # "dummy" is a placeholder processing step name
}
input_token_ints = model_context.encode(prompt)
return DerivedScalarComputationParams(
input_token_ints=input_token_ints,
multi_group_scalar_derivers_by_processing_step=multi_group_scalar_derivers_by_processing_step,
loss_fn_for_backward_pass=loss_fn_for_backward_pass,
device_for_raw_activations=model_context.device,
ablation_specs=ablation_specs,
trace_config=trace_config,
)
def get_derived_scalars_for_prompt(
model_context: StandardModelContext,
dst_and_config_list: list[
tuple[DerivedScalarType, DstConfig | None]
], # None -> default config for dst
prompt: str,
loss_fn_for_backward_pass: Callable[[torch.Tensor], torch.Tensor] | None = None,
trace_config: TraceConfig | None = None,
autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None = None,
ablation_specs: list[AblationSpec] = [],
) -> tuple[DerivedScalarStore, InferenceAndTokenData, RawActivationStore]:
"""
Lightweight function to populate a DerivedScalarStore given information specifying the prompt, loss function, and derived scalars to compute.
"""
multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
autoencoder_context
)
input_token_ints = model_context.encode(prompt)
input_token_strings = [model_context.decode_token(token_int) for token_int in input_token_ints]
dst_and_config_list_with_default_config = apply_default_dst_configs_to_dst_and_config_list(
model_context, multi_autoencoder_context, dst_and_config_list
)
multi_group_scalar_derivers_by_processing_step = {
"dummy": MultiGroupScalarDerivers.from_dst_and_config_list(
dst_and_config_list_with_default_config
)
} # "dummy" is a placeholder processing step name
ds_computation_params = DerivedScalarComputationParams(
input_token_ints=input_token_ints,
multi_group_scalar_derivers_by_processing_step=multi_group_scalar_derivers_by_processing_step,
loss_fn_for_backward_pass=loss_fn_for_backward_pass,
device_for_raw_activations=model_context.device,
trace_config=trace_config,
ablation_specs=ablation_specs,
)
batched_multi_group_ds_store_by_processing_step: list[dict[str, MultiGroupDerivedScalarStore]]
(
batched_multi_group_ds_store_by_processing_step,
batched_inference_data,
batched_raw_activation_store,
) = compute_derived_scalar_groups_for_input_token_ints(
model_context=model_context,
multi_autoencoder_context=multi_autoencoder_context,
batched_ds_computation_params=[ds_computation_params],
)
ds_store = batched_multi_group_ds_store_by_processing_step[0]["dummy"].to_single_ds_store()
inference_data = batched_inference_data[0]
inference_and_token_data = InferenceAndTokenData(
**inference_data.dict(),
tokens_as_ints=input_token_ints,
tokens_as_strings=input_token_strings,
)
raw_activation_store = batched_raw_activation_store[0]
return ds_store, inference_and_token_data, raw_activation_store
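# Illustrative sketch, not called anywhere in this module: end-to-end use of
# get_derived_scalars_for_prompt. DerivedScalarType.MLP_POST_ACT is used as an example DST and is
# assumed (not verified in this file) to be a member of the enum; passing None as its config
# applies the default DstConfig.
def _example_get_derived_scalars_for_prompt(
    model_context: StandardModelContext,
) -> DerivedScalarStore:
    ds_store, _inference_and_token_data, _raw_activation_store = get_derived_scalars_for_prompt(
        model_context=model_context,
        dst_and_config_list=[(DerivedScalarType.MLP_POST_ACT, None)],
        prompt="The capital of France is",
    )
    return ds_store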
def get_batched_derived_scalars_for_prompt(
model_context: StandardModelContext,
batched_dst_and_config_list: list[
list[tuple[DerivedScalarType, DstConfig | None]]
], # None -> default config for dst
batched_prompt: list[str],
loss_fn_for_backward_pass: Callable[[torch.Tensor], torch.Tensor] | None = None,
trace_config: TraceConfig | None = None,
autoencoder_context: MultiAutoencoderContext | AutoencoderContext | None = None,
ablation_specs: list[AblationSpec] = [],
) -> tuple[list[DerivedScalarStore], list[InferenceAndTokenData], list[RawActivationStore]]:
"""
Lightweight function to populate a DerivedScalarStore given information specifying the prompt, loss function, and derived scalars to compute.
"""
multi_autoencoder_context = MultiAutoencoderContext.from_context_or_multi_context(
autoencoder_context
)
assert len(batched_dst_and_config_list) == len(batched_prompt)
batched_ds_computation_params = [
get_ds_computation_params_for_prompt(
model_context=model_context,
autoencoder_context=autoencoder_context,
dst_and_config_list=dst_and_config_list,
prompt=prompt,
loss_fn_for_backward_pass=loss_fn_for_backward_pass,
trace_config=trace_config,
ablation_specs=ablation_specs,
)
for prompt, dst_and_config_list in zip(batched_prompt, batched_dst_and_config_list)
]
batched_multi_group_ds_store_by_processing_step: list[dict[str, MultiGroupDerivedScalarStore]]
(
batched_multi_group_ds_store_by_processing_step,
batched_inference_data,
batched_raw_activation_store,
) = compute_derived_scalar_groups_for_input_token_ints(
model_context=model_context,
multi_autoencoder_context=multi_autoencoder_context,
batched_ds_computation_params=batched_ds_computation_params,
)
batched_ds_store = [
multi_group_ds_store["dummy"].to_single_ds_store()
for multi_group_ds_store in batched_multi_group_ds_store_by_processing_step
]
batched_inference_and_token_data = []
for ds_computation_params, inference_data in zip(
batched_ds_computation_params, batched_inference_data
):
input_token_ints = ds_computation_params.input_token_ints
input_token_strings = [
model_context.decode_token(token_int) for token_int in input_token_ints
]
batched_inference_and_token_data.append(
InferenceAndTokenData(
**inference_data.dict(),
tokens_as_ints=input_token_ints,
tokens_as_strings=input_token_strings,
)
)
return batched_ds_store, batched_inference_and_token_data, batched_raw_activation_store