// neuron_viewer/src/TransformerDebugger/TransformerDebugger.tsx
// Transformer Debugger: an interpretability tool for inspecting model activations.
import React, { useState, useMemo, useEffect, useCallback, useRef } from "react";
import {
  type MultipleTopKDerivedScalarsResponseData,
  type ModelInfoResponse,
  type InferenceResponseAndResponseDict,
  type TdbRequestSpec,
} from "../client";
import { useLocation, useNavigate } from "react-router-dom";
import { LogitsDisplay } from "./cards/LogitsDisplay";
import { NodeTable } from "./cards/node_table/NodeTable";
import { InferenceParamsDisplay } from "./cards/inference_params/InferenceParamsDisplay";
import { BySequenceTokenDisplay } from "./cards/BySequenceTokenDisplay";
import { getInferenceAndTokenData, getSubResponse } from "./requests/inferenceResponseUtils";
import { Card, CardBody, Link } from "@nextui-org/react";
import { queryToInferenceParams, updateQueryFromInferenceParams } from "./utils/urlParams";
import { useExplanationFetcher } from "./requests/explanationFetcher";
import { InferenceDataFetcher, fetchModelInfo } from "./requests/inferenceDataFetcher";
import DisplayOptions from "./cards/DisplayOptions";
import JsonModal from "./common/JsonModal";
import TokenTable from "./cards/TokenTable";
import {
CommonInferenceParams,
PromptInferenceParams,
} from "./cards/inference_params/inferenceParams";

const TransformerDebugger: React.FC = () => {
  // Top-level component: owns all state and passes it down to the child cards.
const location = useLocation();
const navigate = useNavigate();
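
  // Parse the current URL query string, memoized per location change.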
const query = useMemo(() => new URLSearchParams(location.search), [location.search]);
const {
commonParams: commonParamsFromUrl,
leftPromptParams: leftPromptParamsFromUrl,
rightPromptParams: rightPromptParamsFromUrl,
} = queryToInferenceParams(query);
const [commonInferenceParams, setCommonInferenceParams] =
useState<CommonInferenceParams>(commonParamsFromUrl);
const [leftPromptInferenceParams, setLeftPromptInferenceParams] =
useState<PromptInferenceParams | null>(leftPromptParamsFromUrl);
const [rightPromptInferenceParams, setRightPromptInferenceParams] =
useState<PromptInferenceParams | null>(rightPromptParamsFromUrl);
  // Start in two-prompts mode iff the URL provided right-prompt params.
  const [twoPromptsMode, setTwoPromptsMode] = useState(rightPromptInferenceParams !== null);
const [modelInfo, setModelInfo] = useState<ModelInfoResponse | null>(null);
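
  // Left-prompt params should always be present (see queryToInferenceParams); fail fast otherwise.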
if (!leftPromptInferenceParams) {
throw new Error("leftPromptInferenceParams should never be null");
}
const { explanationMap, setNodesRequestingExplanation } = useExplanationFetcher();
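
  // Write the current inference params back into the URL query string.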
useEffect(() => {
const updatedQuery = updateQueryFromInferenceParams(
query,
commonInferenceParams,
leftPromptInferenceParams,
rightPromptInferenceParams
);
navigate({ search: updatedQuery.toString() });
}, [
commonInferenceParams,
leftPromptInferenceParams,
rightPromptInferenceParams,
navigate,
query,
]);
// TDB has a concept of left and right requests and responses. In cases where one request is a
// "test" request and the other is a baseline, the left request is the test request and the right
// request is the baseline request. In cases where there's only one request, it's the left
// request. In other cases, left vs. right is arbitrary.
const [rightRequest, setRightRequest] = useState<TdbRequestSpec | null>(null);
const [rightResponse, setRightResponse] = useState<InferenceResponseAndResponseDict | null>(null);
const [leftRequest, setLeftRequest] = useState<TdbRequestSpec | null>(null);
const [leftResponse, setLeftResponse] = useState<InferenceResponseAndResponseDict | null>(null);
const [activationServerErrorMessage, setActivationServerErrorMessage] = useState<string | null>(
null
);
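
  // Run inference for the current params; requests, responses, and any server error
  // are recorded via the setters above.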
  // Memoize the fetcher so a new instance is not created on every render.
  const inferenceDataFetcher = useMemo(() => new InferenceDataFetcher(), []);
  const fetchInferenceData = useCallback(async () => {
inferenceDataFetcher.fetch(
modelInfo,
commonInferenceParams,
leftPromptInferenceParams,
rightPromptInferenceParams,
setRightResponse,
setLeftResponse,
setRightRequest,
setLeftRequest,
setActivationServerErrorMessage
);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [commonInferenceParams, leftPromptInferenceParams, rightPromptInferenceParams, modelInfo]);
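
  // Load model info once on mount.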
useEffect(() => {
fetchModelInfo(setModelInfo, setActivationServerErrorMessage);
}, []);
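
  // Track the previous params so shouldFetch can tell whether anything relevant changed.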
  const prevCommonInferenceParamsRef = useRef<CommonInferenceParams>();
  const prevLeftPromptInferenceParamsRef = useRef<PromptInferenceParams>();
  const prevRightPromptInferenceParamsRef = useRef<PromptInferenceParams | null>();
// Call fetchInferenceData once on mount and whenever specific inference parameters change.
useEffect(() => {
const shouldFetch = inferenceDataFetcher.shouldFetch(
commonInferenceParams,
leftPromptInferenceParams,
rightPromptInferenceParams,
prevCommonInferenceParamsRef,
prevLeftPromptInferenceParamsRef,
prevRightPromptInferenceParamsRef
);
if (shouldFetch) {
fetchInferenceData();
}
prevCommonInferenceParamsRef.current = commonInferenceParams;
prevLeftPromptInferenceParamsRef.current = leftPromptInferenceParams;
prevRightPromptInferenceParamsRef.current = rightPromptInferenceParams;
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [
commonInferenceParams,
leftPromptInferenceParams,
rightPromptInferenceParams,
twoPromptsMode,
]);
// Fetch inference data any time the modelInfo changes.
useEffect(() => {
fetchInferenceData();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [modelInfo]);
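
  // Per-card visibility toggles for the sections rendered below.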
const [displaySettings, setDisplaySettings] = useState(
new Map<string, boolean>([
["logits", true],
["node", true],
["bySequenceToken", false],
])
);
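
  // Flip one card's visibility flag.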
  const toggleDisplay = useCallback(
(key: string) => {
setDisplaySettings((prevSettings) => {
const newSettings = new Map(prevSettings);
newSettings.set(key, !newSettings.get(key));
return newSettings;
});
},
[setDisplaySettings]
);
  const shouldShowBySequenceTokenDisplay = () => {
    return Boolean(displaySettings.get("bySequenceToken") && leftResponse);
  };
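
  // Prompts shown in the node table: left always, right only when present.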
const prompts = [leftPromptInferenceParams.prompt];
if (rightPromptInferenceParams) {
prompts.push(rightPromptInferenceParams.prompt);
}
return (
<div>
<div className="flex justify-between m-5">
        <h1 className="text-2xl font-bold">🔠 Transformer Debugger</h1>
<div className="flex space-x-4">
<Link
href="https://github.com/openai/transformer-debugger/blob/main/README.md"
target="_blank"
rel="noopener noreferrer"
>
Introduction
</Link>
<Link
href="https://github.com/openai/transformer-debugger/blob/main/terminology.md"
target="_blank"
rel="noopener noreferrer"
>
Terminology
</Link>
</div>
</div>
<Card>
<CardBody>
<InferenceParamsDisplay
commonInferenceParams={commonInferenceParams}
setCommonInferenceParams={setCommonInferenceParams}
leftPromptInferenceParams={leftPromptInferenceParams}
setLeftPromptInferenceParams={setLeftPromptInferenceParams}
rightPromptInferenceParams={rightPromptInferenceParams}
setRightPromptInferenceParams={setRightPromptInferenceParams}
twoPromptsMode={twoPromptsMode}
setTwoPromptsMode={setTwoPromptsMode}
modelInfo={modelInfo}
fetchInferenceData={fetchInferenceData}
inferenceAndTokenData={getInferenceAndTokenData(leftResponse)}
/>
</CardBody>
</Card>
<Card>
<CardBody>
<div className="flex mt-2 space-x-4">
<DisplayOptions displaySettings={displaySettings} toggleDisplay={toggleDisplay} />
<JsonModal
jsonData={{
commonInferenceParams,
leftPromptInferenceParams,
rightPromptInferenceParams,
leftRequest,
rightRequest,
leftResponse,
rightResponse,
}}
/>
</div>
</CardBody>
</Card>
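      {/* Surface any activation-server error above the data cards. */}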
{activationServerErrorMessage && (
<Card>
<CardBody>
<div className="text-red-500">Error: {activationServerErrorMessage}</div>
</CardBody>
</Card>
)}
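      {/* Logits card: rendered when enabled and the left response includes top output-token logits. */}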
{displaySettings.get("logits") &&
getSubResponse<MultipleTopKDerivedScalarsResponseData>(
leftResponse,
"topOutputTokenLogits"
) &&
getInferenceAndTokenData(leftResponse) && (
<Card>
<CardBody>
<LogitsDisplay
leftPromptInferenceParams={leftPromptInferenceParams}
rightPromptInferenceParams={rightPromptInferenceParams}
rightResponseData={getSubResponse<MultipleTopKDerivedScalarsResponseData>(
rightResponse,
"topOutputTokenLogits"
)}
leftResponseData={
getSubResponse<MultipleTopKDerivedScalarsResponseData>(
leftResponse,
"topOutputTokenLogits"
)!
}
rightInferenceAndTokenData={getInferenceAndTokenData(rightResponse)}
leftInferenceAndTokenData={getInferenceAndTokenData(leftResponse)!}
/>
</CardBody>
</Card>
)}
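      {/* By-sequence-token cards: per-token component sums, left and right side by side. */}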
{shouldShowBySequenceTokenDisplay() && (
<div className="flex">
{leftResponse && (
<div className="w-1/2">
<Card>
<CardBody>
<BySequenceTokenDisplay
responseData={
getSubResponse<MultipleTopKDerivedScalarsResponseData>(
leftResponse,
"componentSumsForTokenDisplay"
)!
}
inferenceAndTokenData={getInferenceAndTokenData(leftResponse)!}
/>
</CardBody>
</Card>
</div>
)}
{rightResponse && (
<div className="w-1/2">
<Card>
<CardBody>
<BySequenceTokenDisplay
responseData={
getSubResponse<MultipleTopKDerivedScalarsResponseData>(
rightResponse,
"componentSumsForTokenDisplay"
)!
}
inferenceAndTokenData={getInferenceAndTokenData(rightResponse)!}
/>
</CardBody>
</Card>
</div>
)}
</div>
)}
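      {/* Token table for the left prompt's tokens (and the right prompt's, when present). */}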
{leftResponse && (
<div className="w-full flex justify-start">
<Card>
<CardBody>
<TokenTable
leftTokens={getInferenceAndTokenData(leftResponse)!.tokensAsStrings}
rightTokens={getInferenceAndTokenData(rightResponse)?.tokensAsStrings}
/>
</CardBody>
</Card>
</div>
)}
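      {/* Node table: top-k components, with explanations fetched on demand. */}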
{displaySettings.get("node") &&
getSubResponse<MultipleTopKDerivedScalarsResponseData>(leftResponse, "topKComponents") &&
getInferenceAndTokenData(leftResponse) && (
<Card>
<CardBody>
<NodeTable
leftResponse={leftResponse}
rightResponse={rightResponse}
explanationMap={explanationMap}
setNodesRequestingExplanation={setNodesRequestingExplanation}
leftPromptInferenceParams={leftPromptInferenceParams}
setLeftPromptInferenceParams={setLeftPromptInferenceParams}
prompts={prompts}
commonInferenceParams={commonInferenceParams}
setCommonInferenceParams={setCommonInferenceParams}
/>
</CardBody>
</Card>
)}
</div>
);
};
export default TransformerDebugger;