neuron_explainer/explanations/explanations.py (136 lines of code) (raw):

# Dataclasses and enums for storing neuron and attention head explanations, their scores, and
# related data. Also, related helper functions for loading them.

from __future__ import annotations

import json
import math
import os.path as osp
from dataclasses import dataclass
from enum import Enum
from typing import Any

from neuron_explainer.activations.activations import NeuronId
from neuron_explainer.fast_dataclasses import FastDataclass, loads, register_dataclass
from neuron_explainer.file_utils import CustomFileHandler, file_exists, read_single_async


class ActivationScale(str, Enum):
    """Which "units" are stored in the expected_activations/distribution_values fields of a
    SequenceSimulation.

    This enum identifies whether the values represent real activations of the neuron or something
    else. Different scales are not necessarily related by a linear transformation.
    """

    NEURON_ACTIVATIONS = "neuron_activations"
    """Values represent real activations of the neuron."""
    SIMULATED_NORMALIZED_ACTIVATIONS = "simulated_normalized_activations"
    """
    Values represent simulated activations of the neuron, normalized to the range [0, 10]. This
    scale is arbitrary and should not be interpreted as a neuron activation.
    """
    HUMAN_PREDICTED_ACTIVATIONS = "human_predicted_activations"
    """
    Values represent human predictions of the neuron's activation, normalized to the range [0, 2]:
      0 = not active
      1 = weakly/possibly active
      2 = strongly/definitely active
    Not used at present.
    """


@register_dataclass
@dataclass
class SequenceSimulation(FastDataclass):
    """The result of a simulation of neuron activations on one text sequence."""

    tokens: list[str]
    """The sequence of tokens that was simulated."""
    expected_activations: list[float]
    """Expected value of the possibly-normalized activation for each token in the sequence."""
    activation_scale: ActivationScale
    """What scale is used for values in the expected_activations field."""
    distribution_values: list[list[float]]
    """
    For each token in the sequence, a list of values from the discrete distribution of activations
    produced from simulation. Tokens will be included here if and only if they are in the top K=15
    tokens predicted by the simulator, and excluded otherwise.

    May be transformed to another unit by calibration. When we simulate a neuron, we produce a
    discrete distribution with values in the arbitrary discretized space of the neuron, e.g. 10%
    chance of 0, 70% chance of 1, 20% chance of 2. We store this as
    distribution_values = [0, 1, 2], distribution_probabilities = [0.1, 0.7, 0.2]. When we
    transform the distribution to the real activation units, we can correspondingly transform the
    values of this distribution to get a distribution in the units of the neuron. E.g. if the
    mapping from the discretized space to the real activation unit of the neuron is f(x) = x/2,
    then the distribution becomes 10% chance of 0, 70% chance of 0.5, 20% chance of 1. We store
    this as distribution_values = [0, 0.5, 1], distribution_probabilities = [0.1, 0.7, 0.2].
    """
    distribution_probabilities: list[list[float]]
    """
    For each token in the sequence, the probability of the corresponding value in
    distribution_values.
    """
    # Deliberately unquoted: `from __future__ import annotations` already stores annotations as
    # strings, and quoting the class name inside a `|` union would evaluate to `str | None`,
    # which makes typing.get_type_hints() raise TypeError.
    uncalibrated_simulation: SequenceSimulation | None = None
    """The result of the simulation before calibration."""


SequenceSimulation.field_renamed("unit", "activation_scale")
SequenceSimulation.field_deleted("response")


@register_dataclass
@dataclass
class ScoredSequenceSimulation(FastDataclass):
    """
    SequenceSimulation result with a score (for that sequence only) and ground truth activations.
    """

    sequence_simulation: SequenceSimulation
    """The result of a simulation of neuron activations."""
    true_activations: list[float]
    """Ground truth activations on the sequence (not normalized)."""
    ev_correlation_score: float
    """
    Correlation coefficient between the expected values of the normalized activations from the
    simulation and the unnormalized true activations of the neuron on the text sequence.
    """
    rsquared_score: float | None = None
    """R^2 of the simulated activations."""
    absolute_dev_explained_score: float | None = None
    """
    Score based on absolute difference between real and simulated activations.
    absolute_dev_explained_score = 1 - mean(abs(real-predicted))/ mean(abs(real))
    """

    def __eq__(self, other: Any) -> bool:
        """Field-by-field equality that treats two NaN float fields as equal.

        Since NaN != NaN in Python, a plain field comparison would report two otherwise identical
        objects as unequal whenever a score is NaN (e.g. an undefined correlation coefficient).
        Returns False for objects that are not ScoredSequenceSimulation instances or that carry a
        different set of attributes.
        """
        if not isinstance(other, ScoredSequenceSimulation):
            return False
        # Key views compare like sets: equal iff both objects have exactly the same attributes.
        if self.__dict__.keys() != other.__dict__.keys():
            return False
        for field_name, self_val in self.__dict__.items():
            other_val = other.__dict__[field_name]
            if self_val != other_val:
                both_nan = (
                    isinstance(self_val, float)
                    and math.isnan(self_val)
                    and isinstance(other_val, float)
                    and math.isnan(other_val)
                )
                if not both_nan:
                    return False
        return True


ScoredSequenceSimulation.field_renamed("simulation", "sequence_simulation")


@register_dataclass
@dataclass
class ScoredSimulation(FastDataclass):
    """Result of scoring a neuron simulation on multiple sequences."""

    scored_sequence_simulations: list[ScoredSequenceSimulation]
    """ScoredSequenceSimulation for each sequence."""
    ev_correlation_score: float | None = None
    """
    Correlation coefficient between the expected values of the normalized activations from the
    simulation and the unnormalized true activations on a dataset created from all score_results.
    (Note that this is not equivalent to averaging across sequences.)
    """
    rsquared_score: float | None = None
    """R^2 of the simulated activations."""
    absolute_dev_explained_score: float | None = None
    """
    Score based on absolute difference between real and simulated activations.
    absolute_dev_explained_score = 1 - mean(abs(real-predicted))/ mean(abs(real)).
    """

    def get_preferred_score(self) -> float | None:
        """
        Return ev_correlation_score, the preferred score for a neuron simulation.

        This method may return None in cases where the score is undefined, for example if the
        normalized activations were all zero, yielding a correlation coefficient of NaN.
        """
        return self.ev_correlation_score


@register_dataclass
@dataclass
class ScoredExplanation(FastDataclass):
    """Simulator parameters and the results of scoring it on multiple sequences."""

    explanation: str
    """The explanation that was simulated and scored."""
    scored_simulation: ScoredSimulation
    """Result of scoring the neuron simulator on multiple sequences."""

    def get_preferred_score(self) -> float | None:
        """
        Return the preferred score of the underlying ScoredSimulation.

        This method may return None in cases where the score is undefined, for example if the
        normalized activations were all zero, yielding a correlation coefficient of NaN.
        """
        return self.scored_simulation.get_preferred_score()


ScoredExplanation.was_previously_named("ScoredExplanationOrBaseline")
ScoredExplanation.field_renamed("explanation_or_baseline", "explanation")


@register_dataclass
@dataclass
class NeuronSimulationResults(FastDataclass):
    """Simulation results and scores for a neuron."""

    neuron_id: NeuronId
    """The neuron these results belong to."""
    scored_explanations: list[ScoredExplanation]
    """Every explanation that was scored for this neuron."""


NeuronSimulationResults.field_renamed("scored_explanation_or_baseline_list", "scored_explanations")


@register_dataclass
@dataclass
class AttentionSimulation(FastDataclass):
    """Attention simulator prediction for a single token pair in one text sequence."""

    tokens: list[str]
    """The sequence of tokens containing the token pair."""
    token_pair_coords: tuple[int, int]
    """The coordinates of the token pair that we're simulating attention for."""
    token_pair_label: int
    """Either 0 or 1 for negative or positive label, respectively."""
    simulation_prediction: float
    """The predicted label for the token pair from the attention simulator."""


@register_dataclass
@dataclass
class ScoredAttentionSimulation(FastDataclass):
    """Result of scoring an attention head simulation on multiple sequences."""

    attention_simulations: list[AttentionSimulation]
    """AttentionSimulation for each simulated token pair."""
    roc_auc_score: float | None = None
    """
    Area under the ROC curve for the attention predictions. Each AttentionSimulation is
    essentially a single binary classification.
    """

    def get_preferred_score(self) -> float | None:
        """Return roc_auc_score, or None if no ROC AUC score was recorded."""
        return self.roc_auc_score


@register_dataclass
@dataclass
class ScoredAttentionExplanation(FastDataclass):
    """Simulator parameters and the results of scoring it on multiple sequences."""

    explanation: str
    """The explanation that was simulated and scored."""
    scored_attention_simulation: ScoredAttentionSimulation
    """Result of scoring the attention head simulator on multiple sequences."""

    def get_preferred_score(self) -> float | None:
        """
        Return the preferred (ROC AUC) score of the underlying ScoredAttentionSimulation.

        This method may return None when no ROC AUC score was recorded for the simulation
        (e.g. because it was undefined).
        """
        return self.scored_attention_simulation.get_preferred_score()


@register_dataclass
@dataclass
class AttentionSimulationResults(FastDataclass):
    """Simulation results and scores for an attention head."""

    # Typing this as NeuronId is not ideal but I'm not sure if we want to rename the type to
    # something more general.
    attention_head_id: NeuronId
    scored_explanations: list[ScoredAttentionExplanation]


AttentionSimulationResults.field_renamed("attention_id", "attention_head_id")


def load_neuron_explanations(
    explanations_path: str, layer_index: str | int, neuron_index: str | int
) -> NeuronSimulationResults | None:
    """Load scored explanations for the specified neuron.

    Reads `<explanations_path>/<layer_index>/<neuron_index>.jsonl` and deserializes its first
    non-blank line with `loads`. Returns None if the file does not exist or has no non-blank
    lines.
    """
    file = osp.join(explanations_path, str(layer_index), f"{neuron_index}.jsonl")
    if not file_exists(file):
        return None
    with CustomFileHandler(file) as f:
        for line in f:
            # Skip blank lines (e.g. a stray leading newline) instead of handing them to `loads`;
            # read_file below ignores empty lines the same way.
            if line.strip():
                return loads(line)
    return None


async def load_neuron_explanations_async(
    explanations_path: str, layer_index: str | int, neuron_index: str | int
) -> NeuronSimulationResults | None:
    """Load scored explanations for the specified neuron, asynchronously.

    Reads `<explanations_path>/<layer_index>/<neuron_index>.jsonl` via read_explanation_file;
    returns None if the file does not exist.
    """
    return await read_explanation_file(
        osp.join(explanations_path, str(layer_index), f"{neuron_index}.jsonl")
    )


async def read_file(filename: str) -> str | None:
    """Read the contents of the given file as a string, asynchronously.

    File can be a local file or a remote file. The contents are decoded as UTF-8 and split on
    "\\n"; empty lines are dropped, and exactly one non-empty line must remain, which is returned.

    Returns:
        The single non-empty line, or None if the file does not exist.

    Raises:
        AssertionError: if the file does not contain exactly one non-empty line.
    """
    try:
        raw_contents = await read_single_async(filename)
    except FileNotFoundError:
        return None
    lines = [line for line in raw_contents.decode("utf-8").split("\n") if len(line) > 0]
    assert len(lines) == 1, filename
    return lines[0]


async def read_explanation_file(explanation_filename: str) -> NeuronSimulationResults | None:
    """Load scored explanations from the given filename, asynchronously.

    Returns None if the file does not exist.
    """
    line = await read_file(explanation_filename)
    return loads(line) if line is not None else None


async def read_json_file(filename: str) -> dict | None:
    """Read the contents of the given file as a JSON object, asynchronously.

    Returns None if the file does not exist.
    """
    line = await read_file(filename)
    return json.loads(line) if line is not None else None