in neuron_explainer/activation_server/tdb_conversions.py [0:0]
def node_ablation_to_ablation_spec(node_ablation: NodeAblation) -> AblationSpec:
    """Translate a node-level ablation request into an ``AblationSpec``.

    The node type selects both the activation location type and the layout of
    the tensor indices:

    * ``ATTENTION_HEAD`` -> ``ATTN_QK_PROBS`` with indices
      ``[sequence token index, "All", activation index]``.
    * ``MLP_NEURON`` -> ``MLP_POST_ACT`` with indices
      ``[sequence token index, activation index]``.
    * Any autoencoder latent type -> the location type returned by
      ``get_autoencoder_alt_from_node_type``, with indices
      ``[sequence token index, activation index]``.

    The resulting index always uses ``PassType.FORWARD`` and the node's own
    ``layer_index``; the ablation value is passed through unchanged.

    Raises:
        ValueError: If the node type is not one of the types listed above.
    """
    target = node_ablation.node_index
    ablation_value = node_ablation.value
    node_type = target.node_type

    if node_type == NodeType.ATTENTION_HEAD:
        location_type = ActivationLocationType.ATTN_QK_PROBS
        # Three-position index; the middle position is the literal "All".
        tensor_indices = [
            get_sequence_token_index(target),
            "All",
            get_activation_index(target),
        ]
    elif node_type == NodeType.MLP_NEURON:
        location_type = ActivationLocationType.MLP_POST_ACT
        tensor_indices = [
            get_sequence_token_index(target),
            get_activation_index(target),
        ]
    elif node_type in (
        NodeType.AUTOENCODER_LATENT,
        NodeType.MLP_AUTOENCODER_LATENT,
        NodeType.ATTENTION_AUTOENCODER_LATENT,
    ):
        # NOTE(review): kept as a branch-local import, as in the original --
        # presumably to avoid an import cycle; confirm before hoisting.
        from neuron_explainer.activations.derived_scalars.autoencoder import (
            get_autoencoder_alt_from_node_type,
        )

        location_type = get_autoencoder_alt_from_node_type(node_type)
        tensor_indices = [
            get_sequence_token_index(target),
            get_activation_index(target),
        ]
    else:
        raise ValueError(f"Unknown node type {node_type}")

    mirrored_index = MirroredActivationIndex(
        activation_location_type=location_type,
        pass_type=PassType.FORWARD,
        # mypy cannot tell that every value placed in tensor_indices above is
        # compatible with AllOrOneIndices, hence the ignore.
        tensor_indices=tensor_indices,  # type: ignore
        layer_index=target.layer_index,
    )
    return AblationSpec(index=mirrored_index, value=ablation_value)