neuron_explainer/activation_server/explainer_routes.py (410 lines of code) (raw):

"""Routes / endpoints for generating explanations of model nodes and scoring those explanations."""

from __future__ import annotations

import os
import os.path as osp
from enum import Enum, unique
from typing import Any, TypeVar

from fastapi import FastAPI, HTTPException

from neuron_explainer.activation_server.explanation_datasets import (
    AZURE_EXPLANATION_DATASET_REGISTRY,
    get_local_cached_explanation_directory,
)
from neuron_explainer.activation_server.load_neurons import load_neuron_from_datasets
from neuron_explainer.activation_server.read_routes import NodeIdAndDatasets
from neuron_explainer.activations.activations import (
    ActivationRecordSliceParams,
    NeuronId,
    NeuronRecord,
)
from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.api_client import ApiClient
from neuron_explainer.explanations.attention_head_scoring import AttentionHeadOneAtATimeScorer
from neuron_explainer.explanations.explainer import (
    AttentionHeadExplainer,
    NeuronExplainer,
    TokenActivationPairExplainer,
)
from neuron_explainer.explanations.explanations import (
    AttentionSimulationResults,
    NeuronSimulationResults,
    ScoredAttentionExplanation,
    ScoredExplanation,
)
from neuron_explainer.explanations.prompt_builder import PromptFormat
from neuron_explainer.explanations.scoring import (
    make_simulator_and_score,
    make_uncalibrated_explanation_simulator,
)
from neuron_explainer.fast_dataclasses.fast_dataclasses import dumps, loads
from neuron_explainer.file_utils import file_exists, read_single_async
from neuron_explainer.models.model_component_registry import NodeType
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable

T = TypeVar("T", bound="BaseMethodId")


@unique
class BaseMethodId(str, Enum):
    """Common base for string-valued method-id enums; adds lookup of a member by its value."""

    @classmethod
    def from_string(cls: type[T], s: str) -> T:
        """Return the member of ``cls`` whose ``value`` equals ``s``.

        Raises:
            ValueError: if no member of ``cls`` has that value.
        """
        found = next((member for member in cls if member.value == s), None)
        if found is None:
            raise ValueError(f"{s} is not a valid {cls.__name__}")
        return found


@unique
class NeuronExplainAndScoreMethodId(BaseMethodId):
    BASELINE = "baseline"
_NEURON_EXPLAINER_REGISTRY: dict[NeuronExplainAndScoreMethodId, NeuronExplainer] = {
    NeuronExplainAndScoreMethodId.BASELINE: TokenActivationPairExplainer(
        model_name="gpt-4o",
        cache=True,
        prompt_format=PromptFormat.CHAT_MESSAGES,
    ),
}


@unique
class AttentionExplainAndScoreMethodId(BaseMethodId):
    BASELINE = "baseline"


# Maybe in the future will split this into one for the explainer and one for the scorer
_ATTENTION_EXPLAINER_REGISTRY: dict[
    AttentionExplainAndScoreMethodId, tuple[AttentionHeadExplainer, AttentionHeadOneAtATimeScorer]
] = {
    AttentionExplainAndScoreMethodId.BASELINE: (
        AttentionHeadExplainer(
            model_name="gpt-4o",
            prompt_format=PromptFormat.CHAT_MESSAGES,
            repeat_strongly_attending_pairs=True,
        ),
        AttentionHeadOneAtATimeScorer(
            model_name="gpt-4o",
            prompt_format=PromptFormat.CHAT_MESSAGES,
        ),
    )
}


@immutable
class ExplanationResult(CamelCaseBaseModel):
    explanations: list[str]
    # TODO(sbills): Get consistent about "dataset" vs "dataset_path".
    dataset: str


@immutable
class ScoreRequest(NodeIdAndDatasets):
    explanation: str
    # If set, only the first `max_sequences` validation activation records are used for scoring.
    max_sequences: int | None = None


@immutable
class ScoreResult(CamelCaseBaseModel):
    score: float
    dataset_path: str


@unique
class ActivationCategory(str, Enum):
    NEURON = "neuron"
    ATTENTION_HEAD = "attention_head"


def define_explainer_routes(
    app: FastAPI,
    neuron_method_id: NeuronExplainAndScoreMethodId,
    attention_head_method_id: AttentionExplainAndScoreMethodId,
) -> None:
    """Register the ``/explain`` and ``/score`` endpoints on ``app``.

    Args:
        app: FastAPI application to attach the routes to.
        neuron_method_id: Key into ``_NEURON_EXPLAINER_REGISTRY`` selecting the neuron explainer.
        attention_head_method_id: Key into ``_ATTENTION_EXPLAINER_REGISTRY`` selecting the
            attention-head explainer / scorer pair.

    Raises:
        KeyError: if either method id has no entry in its registry.
    """
    simulation_client = ApiClient(
        model_name="gpt-4o",
        max_concurrent=5,
        cache=True,
    )
    neuron_explainer = _NEURON_EXPLAINER_REGISTRY[neuron_method_id]
    attention_head_explainer, attention_head_scorer = _ATTENTION_EXPLAINER_REGISTRY[
        attention_head_method_id
    ]

    # Derived scalar types handled as neurons vs. as attention heads. Built once when the routes
    # are defined instead of on every request.
    neuron_dsts = (
        DerivedScalarType.MLP_POST_ACT,
        DerivedScalarType.MLP_AUTOENCODER_LATENT,
        DerivedScalarType.ONLINE_MLP_AUTOENCODER_LATENT,
        DerivedScalarType.AUTOENCODER_LATENT,
        DerivedScalarType.ONLINE_AUTOENCODER_LATENT,
    )
    attention_head_dsts = (
        DerivedScalarType.ATTN_WRITE_NORM,
        DerivedScalarType.ATTENTION_AUTOENCODER_LATENT,
        DerivedScalarType.ONLINE_ATTENTION_AUTOENCODER_LATENT,
        DerivedScalarType.FLATTENED_ATTN_WRITE_TO_LATENT_SUMMED_OVER_HEADS,
    )

    def _map_dst_to_activation_category(
        dst: DerivedScalarType,
    ) -> ActivationCategory:
        """Classify ``dst`` as a neuron-like or attention-head-like activation.

        Raises:
            HTTPException(422): if ``dst`` is in neither supported set.
        """
        if dst in neuron_dsts:
            return ActivationCategory.NEURON
        if dst in attention_head_dsts:
            return ActivationCategory.ATTENTION_HEAD
        raise HTTPException(status_code=422, detail=f"Unsupported derived scalar type {dst}")

    def _get_azure_explanation_path(request: NodeIdAndDatasets, dataset_path: str) -> str | None:
        """Return the public (azure) explanation file path for the node, or None if the dataset
        has no registered public explanation set."""
        if dataset_path in AZURE_EXPLANATION_DATASET_REGISTRY:
            expl_dir = AZURE_EXPLANATION_DATASET_REGISTRY[dataset_path]
            return osp.join(expl_dir, str(request.layer_index), f"{request.activation_index}.jsonl")
        return None

    def _get_local_cached_explanation_path(request: NodeIdAndDatasets, dataset_path: str) -> str:
        """Return the local cache file path for the node's scored explanations.

        The directory name embeds the string forms of the dst and of the active method id; changing
        how either is formatted would orphan existing cache files.
        """
        if request.dst.node_type == NodeType.ATTENTION_HEAD:
            method_id_str = str(attention_head_method_id)
        else:
            method_id_str = str(neuron_method_id)
        cache_dir = get_local_cached_explanation_directory(dataset_path)
        return osp.join(
            cache_dir,
            f"cache_{request.dst}_{method_id_str}",
            str(request.layer_index),
            f"{request.activation_index}.jsonl",
        )

    def _verify_cached_simulation_results(
        request: NodeIdAndDatasets,
        simulation_results: Any,
    ) -> None:
        """Verify the type and node id of deserialized simulation results.

        Raises:
            HTTPException(422): if the object is not a Neuron/AttentionSimulationResults, or if its
                (layer_index, neuron_index) differs from the request's (layer_index,
                activation_index).
        """
        if not isinstance(simulation_results, (NeuronSimulationResults, AttentionSimulationResults)):
            raise HTTPException(
                status_code=422, detail=f"Unexpected type {type(simulation_results)} in cache"
            )
        elem_id = (
            simulation_results.neuron_id
            if isinstance(simulation_results, NeuronSimulationResults)
            else simulation_results.attention_head_id
        )
        if (
            elem_id.layer_index != request.layer_index
            or elem_id.neuron_index != request.activation_index
        ):
            raise HTTPException(
                status_code=422,
                # Report the cached neuron_index (previously layer_index was printed twice).
                detail=(
                    f"Cache id mismatch: requested ({request.layer_index}, "
                    f"{request.activation_index}), cache contained ({elem_id.layer_index}, "
                    f"{elem_id.neuron_index})"
                ),
            )

    def _dedupe_scored_explanations(
        *sources: NeuronSimulationResults | AttentionSimulationResults | None,
    ) -> list[Any]:
        """Collect scored explanations from ``sources`` (None entries skipped), keyed by
        explanation text. A later source's entry replaces an earlier one with the same text but
        keeps the position where that text first appeared."""
        unique_by_text: dict[str, Any] = {}
        for source in sources:
            if source is None:
                continue
            for scored in source.scored_explanations:
                unique_by_text[scored.explanation] = scored
        return list(unique_by_text.values())

    def _merge_simulation_results(
        azure_simulation_results: NeuronSimulationResults | AttentionSimulationResults | None,
        local_simulation_results: NeuronSimulationResults | AttentionSimulationResults | None,
    ) -> NeuronSimulationResults | AttentionSimulationResults | None:
        """Merge scored explanations from the local cache and azure into a single
        NeuronSimulationResults or AttentionSimulationResults object.

        Local entries win over azure entries with the same explanation text. The node id is taken
        from the azure results when present. Returns None if both inputs are None.
        """
        if azure_simulation_results is None and local_simulation_results is None:
            return None
        merged = _dedupe_scored_explanations(azure_simulation_results, local_simulation_results)
        if isinstance(azure_simulation_results, NeuronSimulationResults) or isinstance(
            local_simulation_results, NeuronSimulationResults
        ):
            assert (
                isinstance(azure_simulation_results, NeuronSimulationResults)
                or azure_simulation_results is None
            )
            assert (
                isinstance(local_simulation_results, NeuronSimulationResults)
                or local_simulation_results is None
            )
            return NeuronSimulationResults(
                neuron_id=(
                    azure_simulation_results.neuron_id
                    if azure_simulation_results is not None
                    else local_simulation_results.neuron_id  # type: ignore # mypy doesn't understand that both can't be None
                ),
                scored_explanations=merged,
            )
        assert (
            isinstance(azure_simulation_results, AttentionSimulationResults)
            or azure_simulation_results is None
        )
        assert (
            isinstance(local_simulation_results, AttentionSimulationResults)
            or local_simulation_results is None
        )
        return AttentionSimulationResults(
            attention_head_id=(
                azure_simulation_results.attention_head_id
                if azure_simulation_results is not None
                else local_simulation_results.attention_head_id  # type: ignore # mypy doesn't understand that both can't be None
            ),
            scored_explanations=merged,
        )

    async def _check_disk_for_simulation_results(
        request: NodeIdAndDatasets, dataset_path: str
    ) -> NeuronSimulationResults | AttentionSimulationResults | None:
        """
        If the request is for scored explanations in one of the public sets on azure, return them
        if any exist. Include any scored explanations in the local cache. If there are no
        explanations in azure or the local cache, return None.
        """
        azure_path = _get_azure_explanation_path(request, dataset_path)
        cache_path = _get_local_cached_explanation_path(request, dataset_path)

        azure_simulation_results, local_simulation_results = None, None
        if azure_path is not None and file_exists(azure_path):
            azure_simulation_results = loads(
                await read_single_async(azure_path), backwards_compatible=False
            )
            _verify_cached_simulation_results(request, azure_simulation_results)
        if file_exists(cache_path):
            local_simulation_results = loads(
                await read_single_async(cache_path), backwards_compatible=False
            )
            _verify_cached_simulation_results(request, local_simulation_results)

        # Merge the results from azure and the local cache, deduplicating any scored explanations,
        # so there is a single object containing all the scored explanations for the node.
        return _merge_simulation_results(azure_simulation_results, local_simulation_results)

    async def _explain_neuron(neuron_record: NeuronRecord) -> list[str]:
        """Generate explanations for a neuron from 5 training examples per split.

        Raises:
            HTTPException(422): if the neuron's max activation is negative.
        """
        if neuron_record.max_activation < 0:
            raise HTTPException(status_code=422, detail="Neuron is not activated on the dataset")
        train_activation_records = neuron_record.train_activation_records(
            activation_record_slice_params=ActivationRecordSliceParams(n_examples_per_split=5)
        )
        return await neuron_explainer.generate_explanations(
            all_activations=train_activation_records,
            max_activation=neuron_record.max_activation,
        )

    # NeuronRecord contains attention head activation info (it's an outdated name)
    async def _explain_attention_head(attention_record: NeuronRecord) -> list[str]:
        """Generate explanations for an attention head from 5 training examples per split."""
        train_activation_records = attention_record.train_activation_records(
            activation_record_slice_params=ActivationRecordSliceParams(n_examples_per_split=5)
        )
        return await attention_head_explainer.generate_explanations(
            all_activations=train_activation_records,
            max_tokens=50,
            num_top_pairs_to_display=5,
        )

    @app.post("/explain", response_model=ExplanationResult, tags=["explainer"])
    async def explain(request: NodeIdAndDatasets) -> ExplanationResult:
        """Return cached explanations for the node if any exist, otherwise generate new ones."""
        dataset_path, neuron_record = await load_neuron_from_datasets(request)

        cached_simulation_results = await _check_disk_for_simulation_results(request, dataset_path)
        if cached_simulation_results is not None:
            explanations = [s.explanation for s in cached_simulation_results.scored_explanations]
        else:
            activation_category = _map_dst_to_activation_category(request.dst)
            if activation_category == ActivationCategory.ATTENTION_HEAD:
                explanations = await _explain_attention_head(neuron_record)
            elif activation_category == ActivationCategory.NEURON:
                explanations = await _explain_neuron(neuron_record)
            else:
                raise HTTPException(
                    status_code=422,
                    detail=f"Unsupported activation category for explanation: {activation_category}",
                )
        return ExplanationResult(explanations=explanations, dataset=dataset_path)

    async def _score_neuron(
        cached_simulation_results: NeuronSimulationResults | None,
        neuron: NeuronRecord,
        request: ScoreRequest,
        max_sequences: int | None = None,
    ) -> tuple[float | None, NeuronSimulationResults]:
        """Score an explanation for a neuron. Add it to the cached set of simulation results, or
        create the simulation results object if the cache was empty.

        Note: ``cached_simulation_results`` is mutated in place when not None.

        Raises:
            HTTPException(422): if the neuron's max activation is negative.
        """
        if neuron.max_activation < 0:
            raise HTTPException(status_code=422, detail="Neuron is not activated on the dataset")
        valid_activation_records = neuron.valid_activation_records(
            activation_record_slice_params=ActivationRecordSliceParams(n_examples_per_split=5)
        )
        if max_sequences is not None:
            valid_activation_records = valid_activation_records[:max_sequences]
        scored_simulation = await make_simulator_and_score(
            make_uncalibrated_explanation_simulator(
                request.explanation,
                simulation_client,
                prompt_format=PromptFormat.CHAT_MESSAGES,
            ),
            valid_activation_records,
        )
        scored_explanation = ScoredExplanation(
            explanation=request.explanation,
            scored_simulation=scored_simulation,
        )
        if cached_simulation_results is not None:
            cached_simulation_results.scored_explanations.append(scored_explanation)
        else:
            cached_simulation_results = NeuronSimulationResults(
                neuron_id=NeuronId(
                    neuron_index=request.activation_index,
                    layer_index=request.layer_index,
                ),
                scored_explanations=[scored_explanation],
            )
        return scored_explanation.get_preferred_score(), cached_simulation_results

    async def _score_attention_head(
        cached_simulation_results: AttentionSimulationResults | None,
        attention_record: NeuronRecord,
        request: ScoreRequest,
        max_sequences: int | None = None,
    ) -> tuple[float | None, AttentionSimulationResults]:
        """Score an explanation for an attention head. Add it to the cached set of simulation
        results, or create the simulation results object if the cache was empty.

        Note: ``cached_simulation_results`` is mutated in place when not None.

        Raises:
            HTTPException(422): if the head's max activation is negative.
        """
        if attention_record.max_activation < 0:
            raise HTTPException(
                status_code=422, detail="Attention head is not activated on the dataset"
            )
        valid_activation_records = attention_record.valid_activation_records(
            activation_record_slice_params=ActivationRecordSliceParams(n_examples_per_split=5)
        )
        if max_sequences is not None:
            valid_activation_records = valid_activation_records[:max_sequences]
        scored_attention_simulation = await attention_head_scorer.score_explanation(
            activation_records=valid_activation_records,
            explanation=request.explanation,
            max_activation=attention_record.max_activation,
        )
        scored_explanation = ScoredAttentionExplanation(
            explanation=request.explanation,
            scored_attention_simulation=scored_attention_simulation,
        )
        if cached_simulation_results is not None:
            cached_simulation_results.scored_explanations.append(scored_explanation)
        else:
            cached_simulation_results = AttentionSimulationResults(
                attention_head_id=NeuronId(
                    neuron_index=request.activation_index,
                    layer_index=request.layer_index,
                ),
                scored_explanations=[scored_explanation],
            )
        return scored_explanation.get_preferred_score(), cached_simulation_results

    def _cache_simulation_results_locally(
        request: ScoreRequest,
        dataset_path: str,
        cached_simulation_results: NeuronSimulationResults | AttentionSimulationResults,
    ) -> None:
        """Overwrite the node's local cache file with the updated set of simulation results.

        Always cache locally because we can't write to the public azure bucket.
        """
        cache_path = _get_local_cached_explanation_path(request, dataset_path)
        # Serialize before touching the filesystem so a serialization error can't clobber the
        # existing cache file.
        payload = dumps(cached_simulation_results)
        os.makedirs(osp.dirname(cache_path), exist_ok=True)
        # Write to a sibling temp file and atomically swap it in, so an interrupted write never
        # leaves a truncated cache file that later fails to deserialize.
        tmp_path = f"{cache_path}.tmp"
        with open(tmp_path, "wb") as f:
            f.write(payload)
        os.replace(tmp_path, cache_path)

    @app.post("/score", response_model=ScoreResult, tags=["explainer"])
    async def score(request: ScoreRequest) -> ScoreResult:
        """Return the score for ``request.explanation``, computing and caching it on a miss.

        Raises:
            HTTPException(500): if the (cached or computed) score is None.
            HTTPException(422): for unsupported dst types or non-activating nodes.
        """
        dataset_path, neuron_record = await load_neuron_from_datasets(request)
        cached_simulation_results = await _check_disk_for_simulation_results(request, dataset_path)

        # Cache hit: return the score for the first matching explanation.
        if cached_simulation_results is not None:
            cached_match = next(
                (
                    s
                    for s in cached_simulation_results.scored_explanations
                    if s.explanation == request.explanation
                ),
                None,
            )
            if cached_match is not None:
                cached_score = cached_match.get_preferred_score()
                if cached_score is None:
                    raise HTTPException(status_code=500, detail="Score is unexpectedly undefined")
                return ScoreResult(score=cached_score, dataset_path=dataset_path)

        # Cache miss: compute the score for the requested explanation.
        activation_category = _map_dst_to_activation_category(request.dst)
        if activation_category == ActivationCategory.ATTENTION_HEAD:
            assert cached_simulation_results is None or isinstance(
                cached_simulation_results, AttentionSimulationResults
            )
            # Score and update the cache storage object.
            computed_score, cached_simulation_results = await _score_attention_head(
                cached_simulation_results,
                neuron_record,
                request,
                max_sequences=request.max_sequences,
            )
        elif activation_category == ActivationCategory.NEURON:
            assert cached_simulation_results is None or isinstance(
                cached_simulation_results, NeuronSimulationResults
            )
            # Score and update the cache storage object.
            computed_score, cached_simulation_results = await _score_neuron(
                cached_simulation_results,
                neuron_record,
                request,
                max_sequences=request.max_sequences,
            )
        else:
            raise HTTPException(
                status_code=422,
                detail=f"Unsupported activation category for scoring: {activation_category}",
            )

        if computed_score is None:
            raise HTTPException(status_code=500, detail="Score is unexpectedly undefined")
        _cache_simulation_results_locally(request, dataset_path, cached_simulation_results)
        return ScoreResult(score=computed_score, dataset_path=dataset_path)