def node_ablation_to_ablation_spec()

in neuron_explainer/activation_server/tdb_conversions.py [0:0]


def node_ablation_to_ablation_spec(node_ablation: NodeAblation) -> AblationSpec:
    """Translate a NodeAblation into the equivalent forward-pass AblationSpec.

    The node's type selects which activation location is ablated:

    * ``NodeType.ATTENTION_HEAD`` -> ``ActivationLocationType.ATTN_QK_PROBS`` with tensor
      indices ``[sequence_token_index, "All", activation_index]``.
    * ``NodeType.MLP_NEURON`` -> ``ActivationLocationType.MLP_POST_ACT`` with tensor indices
      ``[sequence_token_index, activation_index]``.
    * ``NodeType.AUTOENCODER_LATENT`` / ``MLP_AUTOENCODER_LATENT`` /
      ``ATTENTION_AUTOENCODER_LATENT`` -> the location returned by
      ``get_autoencoder_alt_from_node_type`` with tensor indices
      ``[sequence_token_index, activation_index]``.

    The resulting index always uses ``PassType.FORWARD``. The layer index is copied from
    the node index, and the ablation value is passed through unchanged.

    Raises:
        ValueError: if the node type is not one of those listed above.
    """
    node_index = node_ablation.node_index
    ablation_value = node_ablation.value
    kind = node_index.node_type

    # Decide the location type (and whether an "All" wildcard slot is needed) before calling
    # any index helpers, so an unsupported node type fails without invoking them.
    if kind == NodeType.ATTENTION_HEAD:
        location_type = ActivationLocationType.ATTN_QK_PROBS
        needs_all_slot = True
    elif kind == NodeType.MLP_NEURON:
        location_type = ActivationLocationType.MLP_POST_ACT
        needs_all_slot = False
    elif kind in (
        NodeType.AUTOENCODER_LATENT,
        NodeType.MLP_AUTOENCODER_LATENT,
        NodeType.ATTENTION_AUTOENCODER_LATENT,
    ):
        # Function-scope import; presumably this avoids an import cycle with the
        # derived_scalars package (TODO confirm).
        from neuron_explainer.activations.derived_scalars.autoencoder import (
            get_autoencoder_alt_from_node_type,
        )

        location_type = get_autoencoder_alt_from_node_type(kind)
        needs_all_slot = False
    else:
        raise ValueError(f"Unknown node type {node_index.node_type}")

    token_slot = get_sequence_token_index(node_index)
    activation_slot = get_activation_index(node_index)
    # Attention-head ablations place a wildcard "All" between the token and activation
    # indices (presumably spanning the attended-to positions — TODO confirm).
    if needs_all_slot:
        tensor_indices = [token_slot, "All", activation_slot]
    else:
        tensor_indices = [token_slot, activation_slot]

    mirrored_index = MirroredActivationIndex(
        activation_location_type=location_type,
        pass_type=PassType.FORWARD,
        # mypy cannot see that each list built above satisfies AllOrOneIndices.
        tensor_indices=tensor_indices,  # type: ignore
        layer_index=node_index.layer_index,
    )
    return AblationSpec(index=mirrored_index, value=ablation_value)