neuron-explainer/neuron_explainer/explanations/puzzles.py (32 lines of code) (raw):
import json
import os
from dataclasses import dataclass
from neuron_explainer.activations.activations import ActivationRecord
@dataclass(frozen=True)
class Puzzle:
"""A puzzle is a ground truth explanation, a collection of sentences (stored as ActivationRecords) with activations
according to that explanation, and a collection of false explanations"""
name: str
explanation: str
activation_records: list[ActivationRecord]
false_explanations: list[str]
def convert_puzzle_to_tokenized_sentences(puzzle: Puzzle) -> list[list[str]]:
"""Converts a puzzle to a list of tokenized sentences."""
return [record.tokens for record in puzzle.activation_records]
def convert_puzzle_dict_to_puzzle(puzzle_dict: dict) -> Puzzle:
"""Converts a json dictionary representation of a puzzle to the Puzzle class."""
puzzle_activation_records = []
for sentence in puzzle_dict["sentences"]:
# Token-activation pairs are listed as either a string or a list of a string and a float. If it is a list, the float is the activation.
# If it is only a string, the activation is assumed to be 0. This is useful for readability and reducing redundancy in the data.
tokens = [t[0] if type(t) is list else t for t in sentence]
assert all([type(t) is str for t in tokens]), "All tokens must be strings"
activations = [float(t[1]) if type(t) is list else 0.0 for t in sentence]
assert all([type(t) is float for t in activations]), "All activations must be floats"
puzzle_activation_records.append(ActivationRecord(tokens=tokens, activations=activations))
return Puzzle(
name=puzzle_dict["name"],
explanation=puzzle_dict["explanation"],
activation_records=puzzle_activation_records,
false_explanations=puzzle_dict["false_explanations"],
)
PUZZLES_BY_NAME: dict[str, Puzzle] = dict()
script_dir = os.path.dirname(os.path.abspath(__file__))
with open(os.path.join(script_dir, "puzzles.json"), "r") as f:
puzzle_dicts = json.loads(f.read())
for name in puzzle_dicts.keys():
PUZZLES_BY_NAME[name] = convert_puzzle_dict_to_puzzle(puzzle_dicts[name])