neuron_explainer/activations/hook_graph.py

""" This module contains classes for injecting hooks into a Transformer using the ActivationLocationType and PassType ontology. These produce activation location types that would not have otherwise existed. For example, the AutoencoderHookGraph is necessary for the ActivationLocationType.ONLINE_AUTOENCODER_LATENT location to exist. """ from abc import ABC from copy import deepcopy from typing import Any, Callable, Mapping, cast import torch from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType from neuron_explainer.models.autoencoder_context import AutoencoderContext from neuron_explainer.models.hooks import ( AtLayers, AutoencoderHooks, HookCollection, Hooks, TransformerHooks, ) from neuron_explainer.models.inference_engine_type_registry import ( InferenceEngineType, get_hook_location_type_for_activation_location_type, standard_model_activation_location_types, ) from neuron_explainer.models.model_component_registry import ( ActivationLocationType, ActivationLocationTypeAndPassType, LayerIndex, PassType, ) def unflatten(f: Callable) -> Callable: def _f(x: torch.Tensor) -> torch.Tensor: return f(x.reshape(-1, x.shape[-1])).reshape(x.shape[0], x.shape[1], -1) return _f def _append_to_hook_collection_using_string_list( hook_collection: HookCollection, string_list: list[str], hook: Callable ) -> None: assert len(string_list) > 0 assert ( string_list[0] in hook_collection.all_hooks ), f"string_list: {string_list}, hook_collection: {hook_collection}" sub_hook_collection = hook_collection.all_hooks[string_list[0]] if len(string_list) == 1: assert isinstance( sub_hook_collection, Hooks ), f"string_list: {string_list}, hook_collection: {type(hook_collection)}, sub_hook_collection: {type(sub_hook_collection)}" sub_hook_collection.append(hook) else: assert isinstance( sub_hook_collection, HookCollection ), f"string_list: {string_list}, hook_collection: {type(hook_collection)}, sub_hook_collection: {type(sub_hook_collection)}" _append_to_hook_collection_using_string_list(sub_hook_collection, string_list[1:], hook) def _append_to_hook_collection_using_activation_location_type_and_pass_type( hook_collection: HookCollection, activation_location_type_and_pass_type: ActivationLocationTypeAndPassType, hook: Callable, append_to_fwd2: bool = False, ) -> None: activation_location_type = activation_location_type_and_pass_type.activation_location_type pass_type = activation_location_type_and_pass_type.pass_type standard_model_hook_location_type = get_hook_location_type_for_activation_location_type( activation_location_type, inference_engine_type=InferenceEngineType.STANDARD ) if ( "resid" in standard_model_hook_location_type and "post_emb" not in standard_model_hook_location_type and "ln_f" not in standard_model_hook_location_type and "post_ln_f" not in standard_model_hook_location_type ): # an extra "torso" is needed for the residual location types standard_model_hook_location_type = standard_model_hook_location_type.replace( "resid", "resid/torso" ) string_list = standard_model_hook_location_type.split("/") # e.g. 
["mlp", "post_act"] if append_to_fwd2: assert pass_type == PassType.FORWARD string_list += ["fwd2"] # called after all "fwd" and "bwd" hooks else: string_list += [_pass_type_hc_name_by_hook_pass_type[pass_type]] _append_to_hook_collection_using_string_list(hook_collection, string_list, hook) _pass_type_hc_name_by_hook_pass_type: dict[PassType, str] = { PassType.FORWARD: "fwd", PassType.BACKWARD: "bwd", } class PerLayerHookCollection(HookCollection): """ Organizes HookCollections by layer; supports e.g. appending to the same location within each per-layer HookCollection by supplying a callable to apply_fn_to_all_layers, to do that appending. """ def __init__(self, hook_collection_by_layer: Mapping[LayerIndex, HookCollection]) -> None: super().__init__() for layer in hook_collection_by_layer.keys(): self.add_subhooks(layer, hook_collection_by_layer[layer]) def __call__(self, x: torch.Tensor, *, layer: LayerIndex = None, **kwargs: Any) -> torch.Tensor: if layer in self.all_hooks: return self.all_hooks[layer](x, layer=layer, **kwargs) else: return x def append_to_all_layers_using_string_list( self, string_list: list[str], hook: Callable ) -> None: for layer in self.all_hooks.keys(): _append_to_hook_collection_using_string_list(self.all_hooks[layer], string_list, hook) def __deepcopy__(self, memo: dict) -> "PerLayerHookCollection": # can't use deepcopy because of __getattr__ hook_collection_by_layer = self.all_hooks new = self.__class__(self.all_hooks) new.all_hooks = deepcopy(self.all_hooks) return new class HookGraph(ABC): """ This is a wrapper for HookCollection objects that supports 1. adding hooks at points specified using activation_location_type + pass_type and optionally layer_indices 2. adding subgraphs that are themselves HookGraphs, in such a way that activation_location_types within the subgraph remain accessible by the same activation_location_type + pass_type + layer_indices interface """ hook_collection: HookCollection activation_locations: set[ActivationLocationType] subgraph_by_name: dict[str, "InjectableHookGraph"] def __call__(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: return self.hook_collection(x, *args, **kwargs) # type: ignore def append( self, activation_location_type_and_pass_type: ActivationLocationTypeAndPassType, hook: Callable, layer_indices: int | list[int] | None = None, append_to_fwd2: bool = False, ) -> None: pass def inject_subgraph( self, # activation_location_type_and_pass_type: ActivationLocationTypeAndPassType, subgraph: "InjectableHookGraph", name: str, layer_indices: int | list[int] | None = None, ) -> None: activation_location_type_and_pass_type = subgraph.at_activation_location_type_and_pass_type activation_location_type = activation_location_type_and_pass_type.activation_location_type pass_type = activation_location_type_and_pass_type.pass_type assert ( activation_location_type in self.activation_locations ), f"{activation_location_type} not in {self.activation_locations}" assert name not in self.subgraph_by_name # assert no overlap between activation locations of self and graph assert not self.activation_locations.intersection(subgraph.activation_locations), ( self.activation_locations, subgraph.activation_locations, ) self.append( activation_location_type_and_pass_type=activation_location_type_and_pass_type, hook=subgraph, layer_indices=layer_indices, append_to_fwd2=True, # we inject the subgraph after the forward and backward hooks ) self.subgraph_by_name[name] = subgraph self.activation_locations = 


class InjectableHookGraph(HookGraph):
    """
    This is a HookGraph that can be injected into another HookGraph. It contains one extra piece
    of information: the activation_location_type_and_pass_type where it is to be injected.
    """

    at_activation_location_type_and_pass_type: ActivationLocationTypeAndPassType


class TransformerHookGraph(HookGraph):
    """
    This is a HookGraph that specifically wraps TransformerHooks. It can be used with the
    Transformer.forward() function call using the transformer_graph.as_transformer_hooks() method.
    """

    def __init__(self) -> None:
        self.hook_collection = TransformerHooks()
        self.subgraph_by_name: dict[str, InjectableHookGraph] = {}
        self.activation_locations = standard_model_activation_location_types

    def append(
        self,
        activation_location_type_and_pass_type: ActivationLocationTypeAndPassType,
        hook: Callable,
        layer_indices: int | list[int] | None = None,
        append_to_fwd2: bool = False,
    ) -> None:
        activation_location_type = activation_location_type_and_pass_type.activation_location_type
        pass_type = activation_location_type_and_pass_type.pass_type
        if layer_indices is not None:
            assert (
                not activation_location_type.has_no_layers
            ), f"activation_location_type: {activation_location_type}, layer_indices: {layer_indices}"
            hook = AtLayers(layer_indices).append(hook)
        assert (
            activation_location_type in self.activation_locations
        ), f"{activation_location_type} not in {self.activation_locations}"
        if activation_location_type in standard_model_activation_location_types:
            _append_to_hook_collection_using_activation_location_type_and_pass_type(
                self.hook_collection,
                activation_location_type_and_pass_type,
                hook,
                append_to_fwd2,
            )
        else:
            for name in self.subgraph_by_name.keys():
                if activation_location_type in self.subgraph_by_name[name].activation_locations:
                    self.subgraph_by_name[name].append(
                        activation_location_type_and_pass_type, hook
                    )

    def as_transformer_hooks(self) -> TransformerHooks:
        return cast(TransformerHooks, self.hook_collection)
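

# Illustrative usage sketch for TransformerHookGraph, assuming a transformer whose forward pass
# accepts the TransformerHooks object returned by as_transformer_hooks() (as the class docstring
# suggests). The transformer instance, the input tokens, the hook signature, and the use of
# ActivationLocationType.MLP_POST_ACT are assumptions here, not guaranteed by this file:
#
#   graph = TransformerHookGraph()
#   saved: dict[Any, torch.Tensor] = {}
#
#   def save_activation(x: torch.Tensor, *, layer: Any = None, **kwargs: Any) -> torch.Tensor:
#       saved[layer] = x.detach()
#       return x  # hooks return the (possibly modified) activation
#
#   graph.append(
#       ActivationLocationTypeAndPassType(ActivationLocationType.MLP_POST_ACT, PassType.FORWARD),
#       hook=save_activation,
#       layer_indices=[0, 1],
#   )
#   transformer(tokens, hooks=graph.as_transformer_hooks())  # hypothetical forward signature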


class AutoencoderHookGraph(InjectableHookGraph):
    """
    This is a HookGraph that specifically wraps a PerLayerHookCollection of AutoencoderHooks
    (in general, one per layer).
    """

    def __init__(
        self, autoencoder_context: AutoencoderContext, is_one_of_multiple_autoencoders: bool = False
    ) -> None:
        autoencoder_hooks_by_layer_index: dict[LayerIndex, AutoencoderHooks] = {}
        layer_indices = autoencoder_context.layer_indices or [None]
        for layer_index in layer_indices:
            autoencoder = autoencoder_context.get_autoencoder(layer_index)
            autoencoder_hooks_by_layer_index[layer_index] = AutoencoderHooks(
                encode=unflatten(autoencoder.encode),
                decode=unflatten(autoencoder.decode),
                add_error=True,
            )
        if not autoencoder_context.dst.is_raw_activation_type:
            raise NotImplementedError(
                "AutoencoderHookGraph only supports raw activation types for now."
            )
        self.at_activation_location_type_and_pass_type = ActivationLocationTypeAndPassType(
            autoencoder_context.dst.to_activation_location_type(), PassType.FORWARD
        )
        self.hook_collection = PerLayerHookCollection(autoencoder_hooks_by_layer_index)
        self.location_hc_name_by_activation_location_type = (
            self.get_location_hc_name_by_activation_location_type(
                autoencoder_context.dst, is_one_of_multiple_autoencoders
            )
        )
        self.activation_locations = set(self.location_hc_name_by_activation_location_type.keys())
        self.autoencoder_context = autoencoder_context

    def append(
        self,
        activation_location_type_and_pass_type: ActivationLocationTypeAndPassType,
        hook: Callable,
        layer_indices: int | list[int] | None = None,
        append_to_fwd2: bool = False,
    ) -> None:
        activation_location_type = activation_location_type_and_pass_type.activation_location_type
        pass_type = activation_location_type_and_pass_type.pass_type
        if layer_indices is not None:
            assert (
                not activation_location_type.has_no_layers
            ), f"activation_location_type: {activation_location_type}, layer_indices: {layer_indices}"
            hook = AtLayers(layer_indices).append(hook)
        assert (
            activation_location_type in self.activation_locations
        ), f"{activation_location_type} not in {self.activation_locations}"
        assert pass_type in _pass_type_hc_name_by_hook_pass_type
        string_list = [
            self.location_hc_name_by_activation_location_type[activation_location_type],
            _pass_type_hc_name_by_hook_pass_type[pass_type],
        ]
        self.hook_collection.append_to_all_layers_using_string_list(string_list, hook)

    # note: hc = hook_collection
    def get_location_hc_name_by_activation_location_type(
        self, dst: DerivedScalarType, is_one_of_multiple_autoencoders: bool
    ) -> dict[ActivationLocationType, str]:
        latent_alt_by_dst = {
            DerivedScalarType.MLP_POST_ACT: ActivationLocationType.ONLINE_MLP_AUTOENCODER_LATENT,
            DerivedScalarType.RESID_DELTA_MLP: ActivationLocationType.ONLINE_MLP_AUTOENCODER_LATENT,
            DerivedScalarType.RESID_DELTA_ATTN: ActivationLocationType.ONLINE_ATTENTION_AUTOENCODER_LATENT,
        }
        error_alt_by_dst = {
            DerivedScalarType.MLP_POST_ACT: ActivationLocationType.ONLINE_MLP_AUTOENCODER_ERROR,
            DerivedScalarType.RESID_DELTA_MLP: ActivationLocationType.ONLINE_RESIDUAL_MLP_AUTOENCODER_ERROR,
            DerivedScalarType.RESID_DELTA_ATTN: ActivationLocationType.ONLINE_RESIDUAL_ATTENTION_AUTOENCODER_ERROR,
        }
        location_hc_name_by_alt = {
            latent_alt_by_dst[dst]: "latents",
            error_alt_by_dst[dst]: "error",
        }
        if not is_one_of_multiple_autoencoders:
            # if there is only one autoencoder, we also add the "ONLINE_AUTOENCODER_LATENT"
            # location for backward compatibility
            generic_latent_alt = ActivationLocationType.ONLINE_AUTOENCODER_LATENT
            location_hc_name_by_alt[generic_latent_alt] = "latents"
        return location_hc_name_by_alt
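

# Illustrative end-to-end sketch: building an AutoencoderHookGraph from an AutoencoderContext and
# injecting it into a TransformerHookGraph, which makes ONLINE_AUTOENCODER_LATENT addressable
# through the parent graph. The `autoencoder_context` object (constructed elsewhere) and the
# latent-hook signature are assumptions, not guaranteed by this file:
#
#   transformer_graph = TransformerHookGraph()
#   autoencoder_graph = AutoencoderHookGraph(autoencoder_context)
#   transformer_graph.inject_subgraph(subgraph=autoencoder_graph, name="autoencoder")
#
#   def record_latents(x: torch.Tensor, **kwargs: Any) -> torch.Tensor:
#       print(x.shape)  # e.g. inspect the online autoencoder latents
#       return x
#
#   transformer_graph.append(
#       ActivationLocationTypeAndPassType(
#           ActivationLocationType.ONLINE_AUTOENCODER_LATENT, PassType.FORWARD
#       ),
#       hook=record_latents,
#   )
#   hooks = transformer_graph.as_transformer_hooks()  # pass to the transformer's forward call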