neuron_explainer/activations/derived_scalars/multi_pass_scalar_deriver.py (254 lines of code) (raw):
"""
MultiPassScalarDerivers extend the functionality of ScalarDerivers by specifying how to derive a
scalar from some combination of the activations in multiple prompts.
The pairs of identical interfaces are:
ScalarDeriver:MultiPassScalarDeriver
ScalarSource:MultiPassScalarSource
RawActivationStore:MultiPassRawActivationStore
Both sets of objects can be used in populating a DerivedScalarStore. The intention is that it should
be possible to swap
derived_scalar_store = DerivedScalarStore.derive_from_raw(
multi_pass_raw_activation_store,
multi_pass_scalar_derivers,
)
for
batched_derived_scalar_store = [
DerivedScalarStore.derive_from_raw(
raw_activation_store,
scalar_derivers,
)
for scalar_derivers, raw_activation_store
in zip(batched_scalar_derivers, batched_raw_activation_store)
]
in order to compute derived scalars combining activations across multiple prompts in a batch.
Probable TODO: make an ABC, from which both ScalarDeriver and MultiPassScalarDeriver inherit
"""
from abc import ABC, abstractmethod
from enum import Enum
from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.activations.derived_scalars.derived_scalar_store import RawActivationStore
from neuron_explainer.activations.derived_scalars.locations import LayerIndexer, StaticLayerIndexer
from neuron_explainer.activations.derived_scalars.scalar_deriver import (
ActivationsAndMetadata,
DerivedScalarTypeAndPassType,
ScalarDeriver,
ScalarSource,
)
from neuron_explainer.models.model_component_registry import (
ActivationLocationType,
ActivationLocationTypeAndPassType,
LayerIndex,
LocationWithinLayer,
PassType,
)
class PromptId(Enum):
MAIN = "main"
BASELINE = "baseline"
class PromptCombo(Enum):
MAIN = "main"
BASELINE = "baseline"
SUBTRACT_BASELINE = "subtract_baseline"
@property
def required_prompt_ids(self) -> tuple[PromptId, ...]:
match self:
case PromptCombo.MAIN:
return (PromptId.MAIN,)
case PromptCombo.BASELINE:
return (PromptId.BASELINE,)
case PromptCombo.SUBTRACT_BASELINE:
return (PromptId.MAIN, PromptId.BASELINE)
case _:
raise NotImplementedError
def compute(
self, activations_by_prompt_id: dict[PromptId, ActivationsAndMetadata]
) -> ActivationsAndMetadata:
match self:
case PromptCombo.MAIN:
assert len(activations_by_prompt_id) == 1
main = activations_by_prompt_id.pop(PromptId.MAIN)
return main
case PromptCombo.BASELINE:
assert len(activations_by_prompt_id) == 1
baseline = activations_by_prompt_id.pop(PromptId.BASELINE)
return baseline
case PromptCombo.SUBTRACT_BASELINE:
main = activations_by_prompt_id.pop(PromptId.MAIN)
baseline = activations_by_prompt_id.pop(PromptId.BASELINE)
assert len(activations_by_prompt_id) == 0
return main - baseline
case _:
raise NotImplementedError
def derive_from_raw(
self,
multi_pass_raw_activation_store: "MultiPassRawActivationStore",
scalar_source: ScalarSource,
desired_layer_indices: (
list[LayerIndex] | None
), # indicates layer indices to keep; None indicates keep all
) -> ActivationsAndMetadata:
activations_by_prompt_id: dict[PromptId, ActivationsAndMetadata] = {}
for prompt_id in self.required_prompt_ids:
raw_activation_store = (
multi_pass_raw_activation_store.raw_activation_store_by_prompt_id[prompt_id]
)
activations_by_prompt_id[prompt_id] = scalar_source.derive_from_raw(
raw_activation_store, desired_layer_indices
)
return self.compute(activations_by_prompt_id)
class MultiPassScalarSource(ABC):
pass_type: PassType
layer_indexer: LayerIndexer
@property
@abstractmethod
def exists_by_default(self) -> bool:
# returns True if the activation is instantiated by default in a normal transformer forward pass
# this is False for activations related to autoencoders or for non-trivial derived scalars
pass
@property
@abstractmethod
def dst(self) -> DerivedScalarType:
pass
@property
def dst_and_pass_type(self) -> DerivedScalarTypeAndPassType:
return DerivedScalarTypeAndPassType(
self.dst,
self.pass_type,
)
@property
@abstractmethod
def sub_activation_location_type_and_pass_types(
self,
) -> tuple[ActivationLocationTypeAndPassType, ...]:
pass
@property
@abstractmethod
def location_within_layer(self) -> LocationWithinLayer | None:
pass
@property
def layer_index(self) -> LayerIndex:
"""Convenience method to get the single layer index associated with this ScalarSource, if such a single layer index
exists. Throws an error if it does not."""
assert isinstance(self.layer_indexer, StaticLayerIndexer), (
self.layer_indexer,
"ScalarSource.layer_index should only be called for ScalarSource StaticLayerIndexer",
)
return self.layer_indexer.layer_index
@abstractmethod
def derive_from_raw(
self,
multi_pass_raw_activation_store: "MultiPassRawActivationStore",
desired_layer_indices: (
list[LayerIndex] | None
), # indicates layer indices to keep; None indicates keep all
) -> ActivationsAndMetadata:
"""
See scalar_deriver.ScalarSource.derive_from_raw for explanation.
"""
pass
class SinglePromptComboScalarSource(MultiPassScalarSource):
"""
A SinglePromptComboScalarSource can be computed using some function
derived_scalar_A = f(derived_scalar_A_from_one_prompt[, derived_scalar_A_from_another_prompt, ...])
This is distinct from a MixedScalarSource, which is computed using some function of derived
scalars from SinglePromptComboScalarSources or other MixedScalarSources. For example, a
MixedScalarSource might be computed using some function:
derived_scalar_A = f(
g(sub_derived_scalar_B_from_one_prompt[, sub_derived_scalar_B_from_another_prompt, ...]),
h(sub_derived_scalar_C_from_one_prompt[, sub_derived_scalar_C_from_another_prompt, ...]),
)
"""
scalar_source: ScalarSource
prompt_combo: PromptCombo
def __init__(self, scalar_source: ScalarSource, prompt_combo: PromptCombo):
self.scalar_source = scalar_source
self.prompt_combo = prompt_combo
@property
def exists_by_default(self) -> bool:
return self.scalar_source.exists_by_default
@property
def dst(self) -> DerivedScalarType:
return self.scalar_source.dst
@property
def sub_activation_location_type_and_pass_types(
self,
) -> tuple[ActivationLocationTypeAndPassType, ...]:
return self.scalar_source.sub_activation_location_type_and_pass_types
@property
def location_within_layer(self) -> LocationWithinLayer | None:
return self.scalar_source.location_within_layer
def derive_from_raw(
self,
multi_pass_raw_activation_store: "MultiPassRawActivationStore",
desired_layer_indices: (
list[LayerIndex] | None
), # indicates layer indices to keep; None indicates keep all
) -> ActivationsAndMetadata:
return self.prompt_combo.derive_from_raw(
multi_pass_raw_activation_store, self.scalar_source, desired_layer_indices
)
class MixedScalarSource(MultiPassScalarSource):
multi_pass_scalar_deriver: "MultiPassScalarDeriver"
pass_type: PassType
layer_indexer: LayerIndexer
def __init__(
self,
multi_pass_scalar_deriver: "MultiPassScalarDeriver",
pass_type: PassType,
layer_indexer: LayerIndexer,
):
self.multi_pass_scalar_deriver = multi_pass_scalar_deriver
self.pass_type = pass_type
self.layer_indexer = layer_indexer
@property
def exists_by_default(self) -> bool:
return False
@property
def dst(self) -> DerivedScalarType:
return self.multi_pass_scalar_deriver.dst
@property
def sub_activation_location_type_and_pass_types(
self,
) -> tuple[ActivationLocationTypeAndPassType, ...]:
return self.multi_pass_scalar_deriver.get_sub_activation_location_type_and_pass_types()
@property
def location_within_layer(self) -> LocationWithinLayer | None:
return self.multi_pass_scalar_deriver.scalar_deriver.location_within_layer
def derive_from_raw(
self,
multi_pass_raw_activation_store: "MultiPassRawActivationStore",
desired_layer_indices: (
list[LayerIndex] | None
), # indicates layer indices to keep; None indicates keep all
) -> ActivationsAndMetadata:
return self.multi_pass_scalar_deriver.derive_from_raw(
multi_pass_raw_activation_store, self.pass_type
).apply_layer_indexer(self.layer_indexer, desired_layer_indices)
class MultiPassScalarDeriver:
scalar_deriver: ScalarDeriver
sub_scalar_sources: tuple[MultiPassScalarSource, ...]
def __init__(
self, scalar_deriver: ScalarDeriver, sub_scalar_sources: tuple[MultiPassScalarSource, ...]
):
self.scalar_deriver = scalar_deriver
self.sub_scalar_sources = sub_scalar_sources
assert [
sub_scalar_source.dst_and_pass_type for sub_scalar_source in sub_scalar_sources
] == list(scalar_deriver.get_sub_dst_and_pass_types())
@classmethod
def from_scalar_deriver_and_sub_prompt_combos(
cls,
scalar_deriver: ScalarDeriver,
sub_prompt_combos: tuple[PromptCombo, ...],
) -> "MultiPassScalarDeriver":
assert len(scalar_deriver.sub_scalar_sources) == len(sub_prompt_combos)
sub_scalar_sources = tuple(
[
SinglePromptComboScalarSource(scalar_source, prompt_combo)
for scalar_source, prompt_combo in zip(
scalar_deriver.sub_scalar_sources, sub_prompt_combos
)
]
)
return cls(scalar_deriver, sub_scalar_sources)
@property
def dst(self) -> DerivedScalarType:
return self.scalar_deriver.dst
@property
def derivable_pass_types(self) -> tuple[PassType, ...]:
return self.scalar_deriver.derivable_pass_types
def activations_and_metadata_calculate_derived_scalar_fn(
self, activation_data_tuple: tuple[ActivationsAndMetadata, ...], pass_type: PassType
) -> ActivationsAndMetadata:
return self.scalar_deriver.activations_and_metadata_calculate_derived_scalar_fn(
activation_data_tuple, pass_type
)
def get_sub_activation_location_type_and_pass_types(
self,
) -> tuple[ActivationLocationTypeAndPassType, ...]:
return self.scalar_deriver.get_sub_activation_location_type_and_pass_types()
def derive_from_raw(
self,
multi_pass_raw_activation_store: "MultiPassRawActivationStore",
pass_type: PassType,
) -> ActivationsAndMetadata:
activations_list = []
desired_layer_indices = None
for scalar_source in self.sub_scalar_sources:
activations = scalar_source.derive_from_raw(
multi_pass_raw_activation_store, desired_layer_indices
)
activations_list.append(activations)
if len(activations_list) == 1:
# match the layer_indices of the first activations_and_metadata object
desired_layer_indices = list(activations_list[0].layer_indices)
return self.activations_and_metadata_calculate_derived_scalar_fn(
tuple(activations_list),
pass_type,
)
# TODO: Run PromptCombo.derive_from_raw(scalar_source, raw_activation_store) as a part of
# MultiPassScalarSource.derive_from_raw(raw_activation_store)
class MultiPassRawActivationStore:
raw_activation_store_by_prompt_id: dict[PromptId, RawActivationStore]
def get_activations_and_metadata(
self,
prompt_id: PromptId,
activation_location_type: ActivationLocationType,
pass_type: PassType,
) -> ActivationsAndMetadata:
return self.raw_activation_store_by_prompt_id[prompt_id].get_activations_and_metadata(
activation_location_type, pass_type
)