in neuron_viewer/src/TransformerDebugger/requests/inferenceDataFetcher.ts [141:219]
async fetch(
modelInfo: ModelInfoResponse | null,
commonInferenceParams: CommonInferenceParams,
leftPromptInferenceParams: PromptInferenceParams,
rightPromptInferenceParams: PromptInferenceParams | null,
setRightResponse: React.Dispatch<React.SetStateAction<InferenceResponseAndResponseDict | null>>,
setLeftResponse: React.Dispatch<React.SetStateAction<InferenceResponseAndResponseDict | null>>,
setRightRequest: React.Dispatch<React.SetStateAction<TdbRequestSpec | null>>,
setLeftRequest: React.Dispatch<React.SetStateAction<TdbRequestSpec | null>>,
setActivationServerErrorMessage: React.Dispatch<React.SetStateAction<string | null>>
) {
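    // Nothing to fetch until the model info has loaded.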
if (modelInfo === null) {
return;
}
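    // Clear any stale responses while the new request is in flight.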
setRightResponse(null);
setLeftResponse(null);
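    // Capture the handler so the plain function below can use it without relying on `this`.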
const handleInferenceError = this.handleInferenceError;
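    // Sends one batched request and fans each sub-response out to its corresponding setter.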
function performInference(
subRequests: TdbRequestSpec[],
setResponseFns: React.Dispatch<
React.SetStateAction<InferenceResponseAndResponseDict | null>
>[]
) {
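      // Send all sub-requests to the activation server in one batched call.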
batchedTdb({ subRequests })
.then((responseData) => {
if (responseData.inferenceSubResponses.length !== subRequests.length) {
            throw new Error(
              `Expected exactly ${subRequests.length} inferenceSubResponses, ` +
                `but got ${responseData.inferenceSubResponses.length}`
            );
}
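          // Sub-responses are positional: the i-th response corresponds to the i-th sub-request.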
for (let i = 0; i < responseData.inferenceSubResponses.length; i++) {
setResponseFns[i](responseData.inferenceSubResponses[i]);
}
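          // Clear any error message left over from a previous failed attempt.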
setActivationServerErrorMessage(null);
})
.catch((error) => handleInferenceError(error, setActivationServerErrorMessage));
}
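    // The left request always covers the left prompt, including any node ablations on it.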
const newLeftRequest = buildTdbRequestSpec(leftPromptInferenceParams, commonInferenceParams);
setLeftRequest(newLeftRequest);
    let newRightRequest: TdbRequestSpec | null = null;
if (rightPromptInferenceParams !== null) {
// We're comparing two prompts. The right request covers the right prompt and both prompts use
// the same ablations (which are stored on the left).
newRightRequest = buildTdbRequestSpec(
{
...rightPromptInferenceParams,
nodeAblations: leftPromptInferenceParams.nodeAblations,
upstreamNodeToTrace: leftPromptInferenceParams.upstreamNodeToTrace,
},
commonInferenceParams
);
    } else if (leftPromptInferenceParams.nodeAblations.length !== 0) {
      // We're ablating a single prompt. The right request repeats the left prompt without its
      // ablations so that the ablated and unablated runs can be compared.
      newRightRequest = buildTdbRequestSpec(
        {
          ...leftPromptInferenceParams,
          // The right request omits the ablations specified in the left request.
          nodeAblations: Array<NodeAblation>(),
        },
        commonInferenceParams
      );
    }
    // If there is no right prompt and no ablations, newRightRequest stays null: there's no need
    // for a separate right request.
setRightRequest(newRightRequest);
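    // Batch the left request with the right one (if any); setters line up with requests by index.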
const subRequests = [newLeftRequest];
const setResponseFns = [setLeftResponse];
if (newRightRequest !== null) {
subRequests.push(newRightRequest);
setResponseFns.push(setRightResponse);
}
performInference(subRequests, setResponseFns);
}
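
// --- Illustrative usage sketch (not part of the original file) ---
// A minimal sketch of how a React component might wire fetch() into state, assuming the
// class above is exported as InferenceDataFetcher; the hook name, the `fetcher` argument,
// and the params passed to run() are hypothetical, and only the setter shapes come from
// fetch()'s signature. Type imports are elided, as in the excerpt above.
import { useState } from "react";

function useTdbInference(fetcher: InferenceDataFetcher, modelInfo: ModelInfoResponse | null) {
  const [leftResponse, setLeftResponse] = useState<InferenceResponseAndResponseDict | null>(null);
  const [rightResponse, setRightResponse] = useState<InferenceResponseAndResponseDict | null>(null);
  const [leftRequest, setLeftRequest] = useState<TdbRequestSpec | null>(null);
  const [rightRequest, setRightRequest] = useState<TdbRequestSpec | null>(null);
  const [errorMessage, setErrorMessage] = useState<string | null>(null);

  // Kick off a (possibly batched) inference; results land in the state above.
  const run = (
    common: CommonInferenceParams,
    left: PromptInferenceParams,
    right: PromptInferenceParams | null
  ) =>
    fetcher.fetch(
      modelInfo,
      common,
      left,
      right,
      setRightResponse,
      setLeftResponse,
      setRightRequest,
      setLeftRequest,
      setErrorMessage
    );

  return { run, leftResponse, rightResponse, leftRequest, rightRequest, errorMessage };
}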