# neuron-explainer/neuron_explainer/explanations/calibrated_simulator.py (116 lines of code) (raw):
"""
Code for calibrating simulations of neuron behavior. Calibration refers to a process of mapping from
a space of predicted activation values (e.g. [0, 10]) to the real activation distribution for a
neuron.
See http://go/neuron_explanation_methodology for a description of the calibration step. Necessary for
simulating neurons in the context of ablate-to-simulation, but can be skipped when using correlation
scoring. (Calibration may still improve quality for scoring, at least for non-linear calibration
methods.)
"""
from __future__ import annotations
import asyncio
from abc import abstractmethod
from typing import Optional, Sequence
import numpy as np
from neuron_explainer.activations.activations import ActivationRecord
from neuron_explainer.explanations.explanations import ActivationScale
from neuron_explainer.explanations.simulator import NeuronSimulator, SequenceSimulation
from sklearn import linear_model
class CalibratedNeuronSimulator(NeuronSimulator):
    """
    Wrap a NeuronSimulator and calibrate it to map from the predicted activation space to the
    actual neuron activation space.

    This base class gathers the calibration data and re-packages simulations. Subclasses supply
    the actual mapping by implementing `_calibrate_from_flattened_activations` (fit) and
    `apply_calibration` (apply).
    """

    def __init__(self, uncalibrated_simulator: NeuronSimulator):
        """Store the simulator whose raw predictions will be passed through the calibration."""
        self.uncalibrated_simulator = uncalibrated_simulator

    @classmethod
    async def create(
        cls,
        uncalibrated_simulator: NeuronSimulator,
        calibration_activation_records: Sequence[ActivationRecord],
    ) -> CalibratedNeuronSimulator:
        """
        Create and calibrate a calibrated simulator (so initialization and calibration can be done
        in one call).

        Args:
            uncalibrated_simulator: The simulator to wrap.
            calibration_activation_records: The calibration set passed to `calibrate`.

        Returns:
            An instance of `cls` on which `calibrate` has already been awaited.
        """
        calibrated_simulator = cls(uncalibrated_simulator)
        await calibrated_simulator.calibrate(calibration_activation_records)
        return calibrated_simulator

    async def calibrate(self, calibration_activation_records: Sequence[ActivationRecord]) -> None:
        """
        Determine parameters to map from the predicted activation space to the real neuron
        activation space, based on a calibration set.

        Use when simulated sequences haven't already been produced on the calibration set: the
        wrapped simulator is run on the tokens of every record (concurrently, via
        `asyncio.gather`) and the results are handed to `calibrate_from_simulations`.
        """
        simulations = await asyncio.gather(
            *[
                self.uncalibrated_simulator.simulate(activations.tokens)
                for activations in calibration_activation_records
            ]
        )
        self.calibrate_from_simulations(calibration_activation_records, simulations)

    def calibrate_from_simulations(
        self,
        calibration_activation_records: Sequence[ActivationRecord],
        simulations: Sequence[SequenceSimulation],
    ) -> None:
        """
        Determine parameters to map from the predicted activation space to the real neuron
        activation space, based on a calibration set.

        Use when simulated sequences have already been produced on the calibration set.

        Args:
            calibration_activation_records: Records providing the true activations.
            simulations: Uncalibrated simulations, one per record and in the same order.

        Raises:
            ValueError: If the number of simulations differs from the number of records.
        """
        num_records = len(calibration_activation_records)
        num_simulations = len(simulations)
        # zip() stops at the shorter input, so without this check a mismatch would silently
        # calibrate on only part of the data.
        if num_records != num_simulations:
            raise ValueError(
                "Expected one simulation per calibration activation record, got "
                f"{num_records} records and {num_simulations} simulations"
            )

        flattened_activations = []
        flattened_simulated_activations: list[float] = []
        for activations, simulation in zip(calibration_activation_records, simulations):
            flattened_activations.extend(activations.activations)
            flattened_simulated_activations.extend(simulation.expected_activations)
        self._calibrate_from_flattened_activations(
            np.array(flattened_activations), np.array(flattened_simulated_activations)
        )

    @abstractmethod
    def _calibrate_from_flattened_activations(
        self,
        true_activations: np.ndarray,
        uncalibrated_activations: np.ndarray,
    ) -> None:
        """
        Determine parameters to map from the predicted activation space to the real neuron
        activation space, based on a calibration set.

        Take numpy arrays of all true activations and all uncalibrated activations on the
        calibration set over all sequences.
        """

    @abstractmethod
    def apply_calibration(self, values: Sequence[float]) -> list[float]:
        """Apply the learned calibration to a sequence of values."""

    async def simulate(self, tokens: Sequence[str]) -> SequenceSimulation:
        """
        Simulate `tokens` with the wrapped simulator and calibrate the result.

        The calibration is applied to the expected activations and to each entry of
        `distribution_values`. The distribution probabilities are passed through unchanged, and
        the raw simulation is kept on the result as `uncalibrated_simulation`.

        Returns:
            A SequenceSimulation whose activation scale is `ActivationScale.NEURON_ACTIVATIONS`.
        """
        uncalibrated_seq_simulation = await self.uncalibrated_simulator.simulate(tokens)
        calibrated_activations = self.apply_calibration(
            uncalibrated_seq_simulation.expected_activations
        )
        calibrated_distribution_values = [
            self.apply_calibration(dv) for dv in uncalibrated_seq_simulation.distribution_values
        ]
        return SequenceSimulation(
            tokens=uncalibrated_seq_simulation.tokens,
            expected_activations=calibrated_activations,
            activation_scale=ActivationScale.NEURON_ACTIVATIONS,
            distribution_values=calibrated_distribution_values,
            distribution_probabilities=uncalibrated_seq_simulation.distribution_probabilities,
            uncalibrated_simulation=uncalibrated_seq_simulation,
        )
class UncalibratedNeuronSimulator(CalibratedNeuronSimulator):
    """Expose the calibrated-simulator interface while leaving activation values unchanged."""

    def __init__(self, uncalibrated_simulator: NeuronSimulator):
        super().__init__(uncalibrated_simulator)

    async def calibrate(self, calibration_activation_records: Sequence[ActivationRecord]) -> None:
        """Do nothing: there is nothing to fit, so the calibration records are ignored."""
        return None

    def _calibrate_from_flattened_activations(
        self,
        true_activations: np.ndarray,
        uncalibrated_activations: np.ndarray,
    ) -> None:
        """Do nothing: both arrays are ignored because no mapping is learned."""
        return None

    def apply_calibration(self, values: Sequence[float]) -> list[float]:
        """Return `values` as a list; a list input is returned as-is, without copying."""
        if isinstance(values, list):
            return values
        return list(values)
class LinearCalibratedNeuronSimulator(CalibratedNeuronSimulator):
    """Fit a linear map that takes uncalibrated activations to true activations.

    ev_correlation_score is invariant to linear transformations, so this calibration should
    leave it unchanged.
    """

    def __init__(self, uncalibrated_simulator: NeuronSimulator):
        super().__init__(uncalibrated_simulator)
        # Stays None until calibration has been run.
        self._regression: Optional[linear_model.LinearRegression] = None

    def _calibrate_from_flattened_activations(
        self,
        true_activations: np.ndarray,
        uncalibrated_activations: np.ndarray,
    ) -> None:
        """Fit a single-feature linear regression from uncalibrated to true activations."""
        self._regression = linear_model.LinearRegression()
        # The uncalibrated activations are passed as a single-column 2-D array.
        feature_column = uncalibrated_activations.reshape(-1, 1)
        self._regression.fit(feature_column, true_activations)

    def apply_calibration(self, values: Sequence[float]) -> list[float]:
        """
        Map `values` through the fitted regression.

        Returns an empty list for empty input. Raises ValueError if no regression has been
        fitted yet.
        """
        regression = self._regression
        if regression is None:
            raise ValueError("Must call calibrate() before apply_calibration")
        # len() rather than truthiness, so that numpy array inputs work as well.
        if len(values) == 0:
            return []
        feature_column = np.array(values).reshape(-1, 1)
        predictions = regression.predict(feature_column)
        return predictions.tolist()
class PercentileMatchingCalibratedNeuronSimulator(CalibratedNeuronSimulator):
    """
    Send the nth percentile of the uncalibrated activations to the nth percentile of the true
    activations, for every n.

    The calibrated values reproduce the distribution of true activations on the calibration
    set, but will be overconfident outside of the calibration set.
    """

    def __init__(self, uncalibrated_simulator: NeuronSimulator):
        super().__init__(uncalibrated_simulator)
        # Sorted calibration-set activations; both stay None until calibration has been run.
        self._uncalibrated_activations: Optional[np.ndarray] = None
        self._true_activations: Optional[np.ndarray] = None

    def _calibrate_from_flattened_activations(
        self,
        true_activations: np.ndarray,
        uncalibrated_activations: np.ndarray,
    ) -> None:
        """Store independently sorted copies of both activation arrays."""
        # Sorting each array on its own pairs up equal ranks rather than equal token positions.
        sorted_uncalibrated = np.sort(uncalibrated_activations)
        self._uncalibrated_activations = sorted_uncalibrated
        sorted_true = np.sort(true_activations)
        self._true_activations = sorted_true

    def apply_calibration(self, values: Sequence[float]) -> list[float]:
        """
        Interpolate `values` from the sorted uncalibrated activations onto the sorted true ones.

        Returns an empty list for empty input. Raises ValueError if calibration has not been
        run yet.
        """
        sorted_true = self._true_activations
        sorted_uncalibrated = self._uncalibrated_activations
        if sorted_true is None or sorted_uncalibrated is None:
            raise ValueError("Must call calibrate() before apply_calibration")
        if len(values) == 0:
            return []
        calibrated = np.interp(np.array(values), sorted_uncalibrated, sorted_true)
        return calibrated.tolist()