neuron_explainer/activation_server/explanation_datasets.py (23 lines of code) (raw):
import os
from neuron_explainer.activation_server.load_neurons import convert_dataset_path_to_short_name
# Registry of explanation datasets hosted in Azure blob storage.
#   key:   path of a neuron dataset (a "collated-activations" directory)
#   value: path of the explanation dataset that belongs to that neuron dataset
AZURE_EXPLANATION_DATASET_REGISTRY: dict[str, str] = {
    "https://openaipublic.blob.core.windows.net/neuron-explainer/data/collated-activations/": (
        "https://openaipublic.blob.core.windows.net/neuron-explainer/data/explanations/"
    ),
    "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small_data/collated-activations/": (
        "https://openaipublic.blob.core.windows.net/neuron-explainer/gpt2_small_data/explanations/"
    ),
}
def get_local_cached_explanation_directory(dataset_path: str) -> str:
    """
    Return the local cache directory that holds explanations for a neuron dataset.

    The result is `<root>/cached_explanations/<short name>`, where `<root>` is the directory three
    levels above this module's file and `<short name>` is whatever
    `convert_dataset_path_to_short_name` returns for `dataset_path`. The directory is not created
    or checked for existence here.
    """
    short_name = convert_dataset_path_to_short_name(dataset_path)

    # Climb from this file's absolute path up three directory levels to reach the project root.
    project_root = os.path.abspath(__file__)
    for _ in range(3):
        project_root = os.path.dirname(project_root)

    return f"{project_root}/cached_explanations/{short_name}"
async def get_all_explanation_datasets(neuron_dataset: str) -> list[str]:
    """
    Get all explanation datasets for a given neuron dataset. Search the public azure bucket and also
    the local filesystem cache. Returns a list of paths to the explanation datasets.
    Path can be an azure path (beginning with `https://`) or a local path.

    The azure entry (if the neuron dataset is registered) comes first, followed by the local cache
    directories sorted by name so that the result is deterministic. An empty list is returned when
    nothing is found.
    """
    datasets: list[str] = []

    azure_dataset = AZURE_EXPLANATION_DATASET_REGISTRY.get(neuron_dataset)
    if azure_dataset is not None:
        datasets.append(azure_dataset)

    local_cache_dir = get_local_cached_explanation_directory(neuron_dataset)
    # Iterate through folders to get a list of dirs.
    # There will be different local cache directories if the user generates scored explanations for
    # the same neuron dataset using different neuron/attention explainer registry entries (i.e. so
    # that AttentionExplainAndScoreMethodId or NeuronExplainAndScoreMethodId differ).
    # `os.path.isdir` is already False for a missing path, so no separate existence check is needed.
    if os.path.isdir(local_cache_dir):
        # `os.listdir` returns names in arbitrary order; sort them for a stable result.
        candidate_paths = (
            os.path.join(local_cache_dir, entry) for entry in sorted(os.listdir(local_cache_dir))
        )
        # Only sub-directories count as explanation datasets; plain files are ignored.
        datasets.extend(path for path in candidate_paths if os.path.isdir(path))

    return datasets