# neuron_explainer/activation_server/dst_helpers.py

"""Small helper functions for working with derived scalars in the context of
activation server request handling."""

import math
from typing import Any, Callable, TypeVar

import torch

from neuron_explainer.activation_server.requests_and_responses import *
from neuron_explainer.activations.derived_scalars.derived_scalar_store import DerivedScalarStore
from neuron_explainer.activations.derived_scalars.derived_scalar_types import DerivedScalarType
from neuron_explainer.activations.derived_scalars.indexing import (
    DerivedScalarIndex,
    MirroredNodeIndex,
)
from neuron_explainer.models.model_component_registry import Dimension

T = TypeVar("T")

# Sentinel substituted for non-finite floats (nan/±inf) so the values survive
# serialization in server responses.
_NON_FINITE_SENTINEL = -999


def _float_tensor_to_list(x: torch.Tensor) -> list[float]:
    """Convert a 1-D float tensor to a plain list, replacing non-finite values.

    nan/±inf entries are mapped to ``_NON_FINITE_SENTINEL`` (-999).
    """
    # Use a distinct loop variable; the original shadowed the parameter `x`.
    return [v if math.isfinite(v) else _NON_FINITE_SENTINEL for v in x.tolist()]


def _torch_to_tensor_nd(x: torch.Tensor) -> TensorND:
    """Wrap a 0- to 3-dimensional tensor in the matching TensorND response type.

    Raises:
        NotImplementedError: if ``x`` has more than 3 dimensions.
    """
    ndim = x.ndim
    if ndim == 0:
        return Tensor0D(value=x.item())
    elif ndim == 1:
        return Tensor1D(value=_float_tensor_to_list(x))
    elif ndim == 2:
        return Tensor2D(value=[_float_tensor_to_list(row) for row in x])
    elif ndim == 3:
        return Tensor3D(value=[[_float_tensor_to_list(row) for row in matrix] for matrix in x])
    else:
        raise NotImplementedError(f"Unknown ndim: {ndim}")


def _get_dims_to_keep(
    dst: DerivedScalarType, keep_dimension_fn: Callable[[Dimension], bool]
) -> list[Dimension]:
    """Return the dimensions of ``dst``'s per-token-sequence shape selected by the predicate."""
    return [dim for dim in dst.shape_spec_per_token_sequence if keep_dimension_fn(dim)]


def _sum_dst(
    ds_store: DerivedScalarStore,
    dst: DerivedScalarType,
    keep_dimension_fn: Callable[[Dimension], bool],
    abs_mode: bool,
) -> torch.Tensor:
    """Sum the activations for a single derived scalar type over all dimensions
    not selected by ``keep_dimension_fn``.

    Args:
        ds_store: store holding activations for (at least) ``dst``.
        dst: the derived scalar type to reduce.
        keep_dimension_fn: predicate selecting which dimensions to keep (not sum over).
        abs_mode: if True, sum absolute values instead of raw values.

    Returns:
        A tensor whose rank equals the number of kept dimensions.
    """
    dims_to_keep = _get_dims_to_keep(dst, keep_dimension_fn)
    store_for_dst = ds_store.filter_dsts([dst])
    activations_and_metadata = next(
        iter(store_for_dst.activations_and_metadata_by_dst_and_pass_type.values())
    )
    # Recorded only for the assertion message below, to aid debugging shape mismatches.
    ndim_before_sum = len(activations_and_metadata.shape)
    if abs_mode:
        sum_for_dst = store_for_dst.sum_abs(dims_to_keep=dims_to_keep)
    else:
        sum_for_dst = store_for_dst.sum(dims_to_keep=dims_to_keep)
    assert len(sum_for_dst.shape) == len(
        dims_to_keep
    ), f"{sum_for_dst.shape=}, {ndim_before_sum=}, {dims_to_keep=}"
    return sum_for_dst


def get_intermediate_sum_by_dst(
    ds_store: DerivedScalarStore,
    keep_dimension_fn: Callable[[Dimension], bool],
    abs_mode: bool = False,
) -> dict[DerivedScalarType, TensorND]:
    """Compute per-DST reductions of ``ds_store``, returned as serializable TensorNDs.

    Each derived scalar type in the store is summed over all dimensions not
    selected by ``keep_dimension_fn`` (absolute values if ``abs_mode``).
    """
    dict_of_torch_tensors = {
        dst: _sum_dst(ds_store, dst, keep_dimension_fn, abs_mode=abs_mode)
        for dst in ds_store.dsts
    }
    return {dst: _torch_to_tensor_nd(x) for dst, x in dict_of_torch_tensors.items()}


def get_ds_index_from_node_index(
    node_index: MirroredNodeIndex,
    dsts: list[DerivedScalarType],
) -> DerivedScalarIndex:
    """
    Converts from a MirroredNodeIndex (more general, e.g. defined by a NodeType such as MLP
    neurons) to a DerivedScalarIndex (more specific, e.g. defined by a DerivedScalarType such as
    MLP write norm) conditional on the given derived scalar types, which are assumed to be
    unique for each NodeType.
    """
    dsts_matching_node_type = [dst for dst in dsts if dst.node_type == node_index.node_type]
    assert len(dsts_matching_node_type) == 1, (
        f"Expected exactly one derived scalar type to have node type {node_index.node_type}, "
        f"but found {dsts_matching_node_type} in {dsts}"
    )
    return DerivedScalarIndex.from_node_index(
        node_index=node_index,
        dst=dsts_matching_node_type[0],
    )


def assert_tensor(tensor: Any) -> torch.Tensor:  # for mypy
    """Narrow ``tensor`` to torch.Tensor for the type checker; raises AssertionError otherwise."""
    assert isinstance(tensor, torch.Tensor)
    return tensor