neuron_explainer/activations/derived_scalars/logprobs.py

from typing import Callable

import torch

from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.indexing import ActivationIndex, NodeIndex
from neuron_explainer.activations.derived_scalars.reconstituted import (
    make_apply_logits,
    make_apply_logprobs,
)
from neuron_explainer.activations.derived_scalars.reconstituter_class import Reconstituter
from neuron_explainer.models.model_component_registry import LayerIndex, NodeType, PassType
from neuron_explainer.models.model_context import ModelContext


class LogProbReconstituter(Reconstituter):
    """Reconstitute vocab token logprobs from the final residual stream location. Can be used
    e.g. to compute the effect of residual stream writes on token logprobs, rather than logits."""

    residual_dst: DerivedScalarType = DerivedScalarType.RESID_POST_MLP
    requires_other_scalar_source: bool = False

    def __init__(
        self,
        model_context: ModelContext,
        detach_layer_norm_scale: bool,
    ):
        super().__init__()
        self._model_context = model_context
        self.detach_layer_norm_scale = detach_layer_norm_scale
        transformer = self._model_context.get_or_create_model()
        self._reconstitute_activations_fn = make_apply_logprobs(
            transformer=transformer,
            detach_layer_norm_scale=self.detach_layer_norm_scale,
        )

    def reconstitute_activations(
        self,
        resid: torch.Tensor,
        other_arg: torch.Tensor | None,
        layer_index: LayerIndex,
        pass_type: PassType,
    ) -> torch.Tensor:
        assert other_arg is None
        assert layer_index == self._model_context.n_layers - 1
        assert pass_type == PassType.FORWARD
        return self._reconstitute_activations_fn(resid)


class LogitReconstituter(Reconstituter):
    """Reconstitute vocab token logits from the final residual stream location. Can be used
    e.g. to compute the effect of residual stream writes on token logits, rather than logprobs."""

    residual_dst: DerivedScalarType = DerivedScalarType.RESID_POST_MLP
    requires_other_scalar_source: bool = False

    def __init__(
        self,
        model_context: ModelContext,
        detach_layer_norm_scale: bool,
    ):
        super().__init__()
        self._model_context = model_context
        self.detach_layer_norm_scale = detach_layer_norm_scale
        transformer = self._model_context.get_or_create_model()
        self._reconstitute_activations_fn = make_apply_logits(
            transformer=transformer,
            detach_layer_norm_scale=self.detach_layer_norm_scale,
        )

    def reconstitute_activations(
        self,
        resid: torch.Tensor,
        other_arg: torch.Tensor | None,
        layer_index: LayerIndex,
        pass_type: PassType,
    ) -> torch.Tensor:
        assert other_arg is None
        assert layer_index == self._model_context.n_layers - 1
        assert pass_type == PassType.FORWARD
        return self._reconstitute_activations_fn(resid)

    def get_residual_activation_index(self) -> ActivationIndex:
        # this node index carries only the information that we're interested in the final
        # residual stream layer
        dummy_node_index = NodeIndex(
            node_type=NodeType.LAYER,
            layer_index=self._model_context.n_layers - 1,
            tensor_indices=(),
            pass_type=PassType.FORWARD,
        )
        return self.get_residual_activation_index_for_node_index(
            node_index=dummy_node_index,
        )

    def make_reconstitute_gradient_of_loss_fn(
        self,
        loss_fn: Callable[[torch.Tensor], torch.Tensor],
    ) -> Callable[[torch.Tensor], torch.Tensor]:
        def scalar_hook(
            resid: torch.Tensor,
        ) -> torch.Tensor:
            # loss fn expects a batch dimension
            return loss_fn(resid.unsqueeze(0)).squeeze(0)

        def reconstitute_gradient(
            resid: torch.Tensor,
        ) -> torch.Tensor:
            return self.reconstitute_gradient(
                resid=resid,
                other_arg=None,
                layer_index=self._model_context.n_layers - 1,
                pass_type=PassType.FORWARD,
                scalar_hook=scalar_hook,
            )

        return reconstitute_gradient
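

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). The helper name
# and the choice of loss are hypothetical; it assumes a ModelContext instance
# and a final-layer residual stream tensor are already available, and that the
# base class's reconstitute_gradient (defined elsewhere) accepts a loss
# returning one value per batch element, as suggested by scalar_hook above.
def _example_top_logit_gradient(
    model_context: ModelContext,
    resid: torch.Tensor,
) -> torch.Tensor:
    """Gradient of the top token's logit with respect to the final residual stream."""
    reconstituter = LogitReconstituter(
        model_context=model_context,
        detach_layer_norm_scale=True,
    )

    def loss_fn(logits: torch.Tensor) -> torch.Tensor:
        # Max logit at the last token position, one value per batch element.
        return logits[:, -1, :].max(dim=-1).values

    reconstitute_gradient = reconstituter.make_reconstitute_gradient_of_loss_fn(loss_fn)
    return reconstitute_gradient(resid)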