neuron_explainer/activation_server/load_neurons.py

from fastapi import HTTPException
from neuron_explainer.activation_server.neuron_datasets import (
    NEURON_DATASET_METADATA_REGISTRY,
    get_neuron_dataset_metadata_by_short_name_and_dst,
)
from neuron_explainer.activations.activations import NeuronRecord, load_neuron_async
from neuron_explainer.activations.derived_scalars import DerivedScalarType
from neuron_explainer.pydantic import CamelCaseBaseModel, immutable


@immutable
class NodeIdAndDatasets(CamelCaseBaseModel):
    dst: DerivedScalarType
    layer_index: int
    activation_index: int
    datasets: list[str]
    """A list of dataset paths or short names."""


def resolve_neuron_dataset(dataset: str, dst: DerivedScalarType) -> str:
    """Resolve a dataset path or short name to a full dataset path."""
    if dataset.startswith("https://"):
        # It's already a full dataset path.
        return dataset
    else:
        # It's the short name for a dataset, like "gpt2-small". We have to look up the metadata.
        dataset_metadata = get_neuron_dataset_metadata_by_short_name_and_dst(dataset, dst)
        return dataset_metadata.neuron_dataset_path
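
# Usage sketch (hypothetical URL; "gpt2-small" is the short-name form mentioned above):
#   resolve_neuron_dataset("https://example.com/datasets/foo", dst)  # returned unchanged
#   resolve_neuron_dataset("gpt2-small", dst)  # resolved via the dataset metadata registry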


def convert_dataset_path_to_short_name(dataset_path: str) -> str:
    """Look up the short name registered in NEURON_DATASET_METADATA_REGISTRY for a dataset path."""
    assert dataset_path.startswith("https://")
    short_name = None
    for metadata in NEURON_DATASET_METADATA_REGISTRY.values():
        if metadata.neuron_dataset_path == dataset_path:
            short_name = metadata.short_name
            break
    assert (
        short_name is not None
    ), f"Could not find short name for {dataset_path}. If you're trying to use a custom dataset, ensure that you have added it to neuron_datasets.py:NEURON_DATASET_METADATA_REGISTRY."
    return short_name
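
# Round-trip sketch, assuming "gpt2-small" is registered in NEURON_DATASET_METADATA_REGISTRY for
# the given dst:
#   path = resolve_neuron_dataset("gpt2-small", dst)
#   assert convert_dataset_path_to_short_name(path) == "gpt2-small"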


async def load_neuron_from_datasets(
    node_id_and_datasets: NodeIdAndDatasets,
) -> tuple[str, NeuronRecord]:
    """
    Load a neuron record of the specified dst (e.g. DerivedScalarType.MLP_POST_ACT) from a list of
    datasets, returning the data from the first dataset that has the neuron.

    This allows first trying a dataset that covers only a subset of a model's neurons, then
    falling back to another dataset that covers all of them.
    """
    dst = node_id_and_datasets.dst
    datasets = node_id_and_datasets.datasets
    dataset_paths = [resolve_neuron_dataset(dataset, dst) for dataset in datasets]
    layer_index = node_id_and_datasets.layer_index
    activation_index = node_id_and_datasets.activation_index
    for dataset_path in dataset_paths:
        try:
            return dataset_path, await load_neuron_async(
                dataset_path, layer_index, activation_index
            )
        except FileNotFoundError:
            # This dataset doesn't contain the requested neuron; try the next one.
            pass
    raise HTTPException(
        status_code=404,
        detail=f"Could not find {dst} {layer_index}:{activation_index} in {dataset_paths}",
    )
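

if __name__ == "__main__":
    # Minimal manual-test sketch, not part of the activation server itself. Assumes the
    # "gpt2-small" short name is registered for MLP_POST_ACT, that the dataset files are
    # reachable, that CamelCaseBaseModel accepts snake_case field names at construction time,
    # and that layer/activation index 0 exists; all of these are example values.
    import asyncio

    async def _demo() -> None:
        request = NodeIdAndDatasets(
            dst=DerivedScalarType.MLP_POST_ACT,
            layer_index=0,
            activation_index=0,
            datasets=["gpt2-small"],
        )
        dataset_path, neuron_record = await load_neuron_from_datasets(request)
        print(f"Loaded neuron from {dataset_path}: {neuron_record}")

    asyncio.run(_demo())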