neuron_viewer/src/nodePage.tsx (185 lines of code) (raw):
import React, { useState, useMemo, useCallback, useEffect } from "react";
import { useParams, useLocation } from "react-router-dom";
import ModelInteractions from "./modelInteractions";
import { PaneComponents, PaneComponentType } from "./panes";
import {
DimensionalityOfActivations,
Node,
PROMPTS_SEPARATOR,
PaneData,
getDimensionalityOfActivations,
stringToNodeType,
} from "./types";
import { SectionTitle } from "./commonUiComponents";
import Navigation from "./navigation";
import { AttentionHeadRecordResponse, NeuronRecordResponse } from "./client";
import HeatmapGrid, { HeatmapGridProps } from "./heatmapGrid";
import HeatmapGrid2d, { HeatmapGrid2dProps } from "./heatmapGrid2d";
import { readAttentionHeadRecord, readNeuronRecord } from "./requests/readRequests";
// Math.random returns a value on [0, 1). Start each substring at index 2 to drop the "0." prefix.
const uuid = () =>
Math.random().toString(36).substring(2) + Math.random().toString(36).substring(2);
const Pane: React.FC<{
children: React.ReactNode;
id?: string;
}> = ({ children, id }) => (
<div id={id}>
<div className="flex flex-col h-full">{children}</div>
</div>
);
type PaneItem = { type: PaneComponentType; [key: string]: any };
function getHeatmapComponent(
dimensionalityOfActivations: DimensionalityOfActivations
): React.FC<HeatmapGridProps> | React.FC<HeatmapGrid2dProps> {
switch (dimensionalityOfActivations) {
case DimensionalityOfActivations.SCALAR_PER_TOKEN:
return HeatmapGrid;
case DimensionalityOfActivations.SCALAR_PER_TOKEN_PAIR:
return HeatmapGrid2d;
default:
throw new Error("Unknown dimensionality of activations: " + dimensionalityOfActivations);
}
}
const NodePage = () => {
const [additionalPanes, setAdditionalPanes] = useState<PaneData[]>([]);
const [record, setRecord] = useState<NeuronRecordResponse | AttentionHeadRecordResponse | null>(
null
);
const [isLoadingRecord, setIsLoadingRecord] = useState(true);
const [errorMessage, setErrorMessage] = useState<string | null>(null);
let params =
useParams<{ model: string; nodeTypeStr: string; layerIndex: string; nodeIndex: string }>();
const activeNode: Node = useMemo(() => {
if (
params.nodeTypeStr === undefined ||
params.layerIndex === undefined ||
params.nodeIndex === undefined
) {
throw new Error("One or more of node type, layer index, and node index are undefined.");
} else {
return {
nodeType: stringToNodeType(params.nodeTypeStr),
layerIndex: parseInt(params.layerIndex),
nodeIndex: parseInt(params.nodeIndex),
};
}
}, [params]);
const dimensionalityOfActivations = getDimensionalityOfActivations(activeNode.nodeType);
// Define this constant to make the checks below more concise and readable.
const scalarPerToken =
dimensionalityOfActivations === DimensionalityOfActivations.SCALAR_PER_TOKEN;
const fetchRecord = useCallback(async () => {
switch (dimensionalityOfActivations) {
case DimensionalityOfActivations.SCALAR_PER_TOKEN:
return await readNeuronRecord(activeNode);
case DimensionalityOfActivations.SCALAR_PER_TOKEN_PAIR:
return await readAttentionHeadRecord(activeNode);
default:
throw new Error("Unknown dimensionality of activations: " + dimensionalityOfActivations);
}
}, [activeNode, dimensionalityOfActivations]);
useEffect(() => {
async function fetchData() {
try {
const result = await fetchRecord();
setErrorMessage(null);
setRecord(result);
setIsLoadingRecord(false);
} catch (error) {
if (error instanceof Error) {
setErrorMessage(error.message);
} else {
setErrorMessage("Unknown error");
}
}
}
fetchData();
}, [fetchRecord]);
const addPane = useCallback((items: PaneItem[]) => {
const newPanes = items.map((item) => {
return {
id: uuid().toString(),
...item,
};
});
setAdditionalPanes((panes) => [...panes, ...newPanes]);
}, []);
const location = useLocation();
const urlSearchParams = useMemo(() => new URLSearchParams(location.search), [location.search]);
const promptsParam = urlSearchParams.get("promptsOfInterest");
var promptsOfInterest: string[];
if (promptsParam == null) {
promptsOfInterest = [];
} else {
promptsOfInterest = promptsParam.split(PROMPTS_SEPARATOR);
}
return (
<div>
<Navigation activeNode={activeNode} />
<div className="flow-root" style={{ width: 600, margin: "auto", overflow: "visible" }}>
<ul className="mb-8 mt-10">
{promptsOfInterest?.length ? (
<>
<SectionTitle>Prompts of interest</SectionTitle>
<Pane>
{promptsOfInterest.map((prompt, index) => (
<Pane key={index}>
<PaneComponents.ActivationsForPrompt
activeNode={activeNode}
sentence={prompt}
/>
</Pane>
))}
</Pane>
<div style={{ marginBottom: "15px" }}> </div>
</>
) : null}
{
<Pane>
<PaneComponents.Explanation activeNode={activeNode} />
</Pane>
}
<Pane>
<PaneComponents.DatasetExamples
record={record}
isLoadingRecord={isLoadingRecord}
errorMessage={errorMessage}
heatmapComponent={getHeatmapComponent(dimensionalityOfActivations)}
/>
</Pane>
<Pane>
<PaneComponents.LogitLens activeNode={activeNode} />
</Pane>
{additionalPanes.map((pane) => (
<li key={pane.id}>
<div className="relative">
<div className="relative flex items-start space-x-3">
<Pane key={pane.id} id={pane.id} {...pane}>
{React.createElement(PaneComponents[pane.type] as any, {
...pane,
key: pane.id,
activeNode: activeNode!,
})}
</Pane>
</div>
</div>
</li>
))}
</ul>
<div>
<ModelInteractions
onGetActivationsForPrompt={(sentence) => {
addPane([{ type: "ActivationsForPrompt", sentence }]);
}}
// Scoring explanations is only supported for neurons and autoencoder latents.
onScoreExplanation={
scalarPerToken
? (explanation) => {
addPane([{ type: "ScoreExplanation", explanation }]);
}
: undefined
}
/>
</div>
</div>
</div>
);
};
export default NodePage;